diff --git a/.github/workflows/nix_build.yaml b/.github/workflows/nix_build.yaml index 71ad59d0..e0076af6 100644 --- a/.github/workflows/nix_build.yaml +++ b/.github/workflows/nix_build.yaml @@ -21,7 +21,7 @@ jobs: nix_path: nixpkgs=channel:nixos-unstable - uses: cachix/cachix-action@v14 with: - name: text-generation-inference + name: huggingface # If you chose signing key for write access authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' env: diff --git a/.github/workflows/nix_cache.yaml b/.github/workflows/nix_cache.yaml index 967a5982..7c73e584 100644 --- a/.github/workflows/nix_cache.yaml +++ b/.github/workflows/nix_cache.yaml @@ -20,7 +20,7 @@ jobs: nix_path: nixpkgs=channel:nixos-unstable - uses: cachix/cachix-action@v14 with: - name: text-generation-inference + name: huggingface # If you chose signing key for write access authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}" env: diff --git a/.github/workflows/nix_tests.yaml b/.github/workflows/nix_tests.yaml index d9b91048..4f68ff60 100644 --- a/.github/workflows/nix_tests.yaml +++ b/.github/workflows/nix_tests.yaml @@ -25,7 +25,7 @@ jobs: nix_path: nixpkgs=channel:nixos-unstable - uses: cachix/cachix-action@v14 with: - name: text-generation-inference + name: huggingface # If you chose signing key for write access authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' env: diff --git a/Cargo.lock b/Cargo.lock index c757f885..c4b2572f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4650,7 +4650,7 @@ dependencies = [ [[package]] name = "text-generation-backends-trtllm" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "async-trait", "clap 4.5.32", @@ -4671,7 +4671,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "average", "clap 4.5.32", @@ -4691,7 +4691,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -4709,7 +4709,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "clap 4.5.32", "ctrlc", @@ -4730,7 +4730,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "anyhow", "async-stream", @@ -4782,7 +4782,7 @@ dependencies = [ [[package]] name = "text-generation-router-llamacpp" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "async-trait", "bindgen 0.71.1", @@ -4800,7 +4800,7 @@ dependencies = [ [[package]] name = "text-generation-router-v2" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4849,7 +4849,7 @@ dependencies = [ [[package]] name = "text-generation-router-v3" -version = "3.3.0-dev0" +version = "3.3.2-dev0" dependencies = [ "async-stream", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index df40d8d5..06dc251b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ default-members = [ resolver = "2" [workspace.package] -version = "3.3.0-dev0" +version = "3.3.2-dev0" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/Dockerfile b/Dockerfile index 03840b97..869596d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,7 +48,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install WORKDIR /usr/src/ # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. 
Context: https://github.com/huggingface/text-generation-inference/pull/2099 -ARG PYTORCH_VERSION=2.6 +ARG PYTORCH_VERSION=2.7 ARG PYTHON_VERSION=3.11 # Keep in sync with `server/pyproject.toml @@ -121,13 +121,6 @@ COPY server/Makefile-awq Makefile # Build specific version of transformers RUN . .venv/bin/activate && make build-awq -# Build Lorax Punica kernels -FROM kernel-builder AS lorax-punica-builder -WORKDIR /usr/src -COPY server/Makefile-lorax-punica Makefile -# Build specific version of transformers -RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica - # Build Transformers CUDA kernels FROM kernel-builder AS custom-kernels-builder WORKDIR /usr/src @@ -210,8 +203,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages # Copy build artifacts from awq kernels builder COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages -# Copy build artifacts from lorax punica kernels builder -COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages diff --git a/Dockerfile.neuron b/Dockerfile.neuron index d22ca222..6228dbb7 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -5,7 +5,7 @@ RUN mkdir -p /tgi # Fetch the optimum-neuron sources directly to avoid relying on pypi deployments FROM alpine AS optimum-neuron RUN mkdir -p /optimum-neuron -ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz +ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.2.0.tar.gz /optimum-neuron/sources.tar.gz RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1 # Build cargo components (adapted from TGI original Dockerfile) @@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU # Install neuronx packages RUN apt-get update -y \ && apt-get install -y --no-install-recommends \ - aws-neuronx-dkms=2.19.64.0 \ - aws-neuronx-collectives=2.23.135.0-3e70920f2 \ - aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \ - aws-neuronx-tools=2.20.204.0 \ + aws-neuronx-dkms=2.20.28.0 \ + aws-neuronx-collectives=2.24.59.0-838c7fc8b \ + aws-neuronx-runtime-lib=2.24.53.0-f239092cc \ + aws-neuronx-tools=2.22.61.0 \ libxml2 \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean @@ -125,11 +125,10 @@ RUN pip3 install \ --index-url https://download.pytorch.org/whl/cpu RUN pip3 install \ - neuronx-cc==2.16.372.0 \ - torch-neuronx==2.5.1.2.4.0 \ - transformers-neuronx==0.13.322 \ - neuronx-distributed==0.10.1 \ - libneuronxla==2.1.681.0 \ + neuronx-cc==2.17.194.0 \ + torch-neuronx==2.5.1.2.6.0 \ + neuronx-distributed==0.11.0 \ + libneuronxla==2.2.1630.0 \ --extra-index-url=https://pip.repos.neuron.amazonaws.com # Install HuggingFace packages @@ -160,7 +159,7 @@ RUN pip install dist/text_generation_server*.tar.gz # Final image FROM neuron -COPY backends/neuron/tgi_env.py /tgi_env.py +COPY backends/neuron/tgi_entry_point.py 
/tgi_entry_point.py COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/Dockerfile.nix b/Dockerfile.nix index f1e7e0f5..90390de6 100644 --- a/Dockerfile.nix +++ b/Dockerfile.nix @@ -6,7 +6,7 @@ FROM nixos/nix:2.18.8 AS builder RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf RUN nix profile install nixpkgs#cachix -RUN cachix use text-generation-inference +RUN cachix use huggingface WORKDIR /root ADD . . RUN nix build . diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 06073fe4..02885405 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -1,5 +1,5 @@ # Those arguments are required to build the image -ARG HABANA_VERSION=1.20.0 +ARG HABANA_VERSION=1.21.0 ARG PYTORCH_VERSION=2.6.0 # Rust builder @@ -57,9 +57,12 @@ ARG PYTORCH_VERSION FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base -ENV ATTENTION=default +ENV ATTENTION=paged ENV PREFIX_CACHING=0 ENV PREFILL_CHUNKING=0 +ENV PT_HPU_LAZY_MODE=1 +ENV PT_HPU_WEIGHT_SHARING=0 +ENV VLLM_EXPONENTIAL_BUCKETING=true # Text Generation Inference base env ENV HF_HOME=/data \ @@ -95,7 +98,9 @@ RUN cd server && \ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . --no-cache-dir -RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git +RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix +RUN pip install compressed-tensors==0.9.1 + # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router diff --git a/README.md b/README.md index 0d8fedbd..5586e0c7 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta volume=$PWD/data docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model ``` And then you can make requests like @@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \ **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. -**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0-rocm --model-id $model` instead of the command above. +**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). 
To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2-rocm --model-id $model` instead of the command above. To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): ``` @@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading token= docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model + ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model ``` ### A note on Shared Memory (shm) @@ -256,7 +256,7 @@ Another option is to install `text-generation-inference` locally using [Nix](htt we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can be pulled from a binary cache, removing the need to build them locally. -First follow the instructions to [install Cachix and enable the TGI cache](https://app.cachix.org/cache/text-generation-inference). +First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface). Setting up the cache is important, otherwise Nix will build many of the dependencies locally, which can take hours. diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile index c153a5ff..e135f16e 100644 --- a/backends/gaudi/Makefile +++ b/backends/gaudi/Makefile @@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST))) mkfile_dir := $(dir $(mkfile_path)) root_dir := ${mkfile_dir}/../.. -HABANA_VERSION := 1.20.0 +HABANA_VERSION := 1.21.0 PYTORCH_VERSION := 2.6.0 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install @@ -50,6 +50,7 @@ local-dev-install: install-dependencies # In order to run the integration tests, you need to first build the image (make -C backends/gaudi image) run-integration-tests: + pip install -U pip uv uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt DOCKER_VOLUME=${root_dir}/data \ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \ @@ -57,6 +58,7 @@ run-integration-tests: # This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests capture-expected-outputs-for-integration-tests: + pip install -U pip uv DOCKER_VOLUME=${root_dir}/data \ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \ uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py diff --git a/backends/gaudi/server/integration-tests/test_model.py b/backends/gaudi/server/integration-tests/test_model.py index cb2bf6a9..40b27164 100644 --- a/backends/gaudi/server/integration-tests/test_model.py +++ b/backends/gaudi/server/integration-tests/test_model.py @@ -9,8 +9,8 @@ TEST_CONFIGS = { "meta-llama/Llama-3.1-8B-Instruct-shared": { "model_id": "meta-llama/Llama-3.1-8B-Instruct", "input": "What is Deep Learning?", - "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use", - "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use", + "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to 
analyze and interpret data. It is a type of", + "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of", "args": [ "--sharded", "true", @@ -165,20 +165,6 @@ TEST_CONFIGS = { "4", ], }, - "facebook/opt-125m": { - "model_id": "facebook/opt-125m", - "input": "What is Deep Learning?", - "expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout", - "expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout", - "args": [ - "--max-input-tokens", - "512", - "--max-total-tokens", - "1024", - "--max-batch-size", - "4", - ], - }, "EleutherAI/gpt-j-6b": { "model_id": "EleutherAI/gpt-j-6b", "input": "What is Deep Learning?", diff --git a/backends/gaudi/server/poetry.lock b/backends/gaudi/server/poetry.lock index b9b2e138..c6cace66 100644 --- a/backends/gaudi/server/poetry.lock +++ b/backends/gaudi/server/poetry.lock @@ -1058,199 +1058,6 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -description = "CUBLAS native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"}, -] - -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -description = "CUDA profiling tools runtime libs." 
-optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"}, -] - -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -description = "NVRTC native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"}, -] - -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -description = "CUDA Runtime native Libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"}, -] - -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -description = "cuDNN runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, - {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" - -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -description = "CUFFT native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"}, -] - -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = 
"nvidia-curand-cu12" -version = "10.3.5.147" -description = "CURAND native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"}, -] - -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -description = "CUDA solver native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"}, -] - -[package.dependencies] -nvidia-cublas-cu12 = "*" -nvidia-cusparse-cu12 = "*" -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -description = "CUSPARSE native runtime libraries" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"}, -] - -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - -[[package]] -name = "nvidia-cusparselt-cu12" -version = "0.6.2" -description = "NVIDIA cuSPARSELt" -optional = false -python-versions = "*" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8"}, - {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9"}, - {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70"}, -] - -[[package]] -name = "nvidia-nccl-cu12" -version = "2.21.5" -description = "NVIDIA Collective Communication Library (NCCL) Runtime" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"}, -] - -[[package]] 
-name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -description = "Nvidia JIT LTO Library" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, -] - -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -description = "NVIDIA Tools Extension" -optional = false -python-versions = ">=3" -groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" -files = [ - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, -] - [[package]] name = "opentelemetry-api" version = "1.32.0" @@ -2650,63 +2457,6 @@ files = [ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] -[[package]] -name = "torch" -version = "2.6.0" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -optional = false -python-versions = ">=3.9.0" -groups = ["main"] -files = [ - {file = "torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961"}, - {file = "torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab"}, - {file = "torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341"}, - {file = "torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628"}, - {file = "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1"}, - {file = "torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d"}, - {file = "torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7"}, - {file = "torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21"}, - {file = "torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9"}, - {file = "torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb"}, - {file = "torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239"}, - {file = "torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989"}, - 
{file = "torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf"}, - {file = "torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b"}, - {file = "torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc"}, - {file = "torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"}, - {file = "torch-2.6.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053"}, - {file = "torch-2.6.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e"}, - {file = "torch-2.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a"}, - {file = "torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c"}, -] - -[package.dependencies] -filelock = "*" -fsspec = "*" -jinja2 = "*" -networkx = "*" -nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparselt-cu12 = {version = "0.6.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -setuptools = {version = "*", markers = "python_version >= \"3.12\""} -sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""} -triton = {version = "3.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -typing-extensions = ">=4.10.0" - -[package.extras] -opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.13.0)"] - [[package]] name = "tqdm" version = "4.67.1" diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml index 3f2676cb..fa2c2697 100644 --- a/backends/gaudi/server/pyproject.toml +++ b/backends/gaudi/server/pyproject.toml @@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0" hf-transfer = "^0.1.9" sentencepiece = "^0.2.0" peft = 
"^0.15" -optimum-habana = "1.17" -transformers = "^4.49" +transformers = "^4.52.4" numpy = "^1.26" -accelerate = "^0.33" +accelerate = "^1.7.0" outlines= { version = "^0.0.36", optional = true } prometheus-client = "^0.21.1" py-cpuinfo = "^9.0.0" diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt index 1a5d767f..e6c9abf2 100644 --- a/backends/gaudi/server/requirements.txt +++ b/backends/gaudi/server/requirements.txt @@ -1,4 +1,4 @@ -accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" +accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13" attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13" certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13" @@ -36,19 +36,6 @@ nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13" networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" -nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" -nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13" @@ -59,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13" 
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13" -optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13" optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13" outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13" packaging==24.2 ; python_version >= "3.9" and python_version < "3.13" @@ -88,9 +74,8 @@ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13" sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13" -torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13" triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64" typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 53837ef7..dc31ab2f 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -1,6 +1,4 @@ import os -import psutil -import signal import sys import typer @@ -19,6 +17,7 @@ class Quantization(str, Enum): gptq = "gptq" awq = "awq" fp8 = "fp8" + compressed_tensors = "compressed-tensors" class Dtype(str, Enum): @@ -26,6 +25,11 @@ class Dtype(str, Enum): bloat16 = "bfloat16" +class KVCacheDtype(str, Enum): + fp8_e4m3fn = "fp8_e4m3fn" + fp8_e5m2 = "fp8_e5m2" + + @app.command() def serve( model_id: str, @@ -34,6 +38,7 @@ def serve( quantize: Optional[Quantization] = None, speculate: Optional[int] = None, dtype: Optional[Dtype] = None, + kv_cache_dtype: Optional[KVCacheDtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", logger_level: str = "INFO", @@ -93,7 +98,8 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = "bfloat16" if dtype is None else dtype.value - logger.info(f"quantize={quantize}") + kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value + logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}") if dtype is not None and quantize not in { None, "bitsandbytes", @@ -102,83 +108,24 @@ def serve( "gptq", "awq", "fp8", + "compressed-tensors", }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." 
) - - logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) - - if sharded and os.getenv("ATTENTION", "default") not in {"paged"}: - tgi_file = Path(__file__).resolve().parent / "tgi_service.py" - num_shard = int(os.getenv("WORLD_SIZE", "1")) - logger.info("CLI SHARDED = {}".format(num_shard)) - import subprocess - - cmd = ( - f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" - ) - cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" - cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" - cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}" - if speculate is not None: - cmd += f"--speculate {speculate}" - logger.info("CLI server start deepspeed ={} ".format(cmd)) - sys.stdout.flush() - sys.stderr.flush() - with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: - do_terminate = False - current_handler = signal.getsignal(signal.SIGTERM) - - def terminate_handler(sig, frame): - nonlocal do_terminate - do_terminate = True - if callable(current_handler): - current_handler(sig, frame) - - signal.signal(signal.SIGTERM, terminate_handler) - - finished = False - while not finished: - try: - if do_terminate: - parent = psutil.Process(proc.pid) - all_procs = parent.children(recursive=True) + [parent] - for p in all_procs: - try: - p.terminate() - except psutil.NoSuchProcess: - pass - _, alive = psutil.wait_procs(all_procs, timeout=30) - for p in alive: - p.kill() - - do_terminate = False - - proc.wait(timeout=3) - except subprocess.TimeoutExpired: - pass - else: - finished = True - - sys.stdout.flush() - sys.stderr.flush() - if proc.returncode != 0: - logger.error(f"{cmd} exited with status = {proc.returncode}") - return proc.returncode - else: - server.serve( - model_id, - lora_adapters, - revision, - sharded, - quantize, - speculate, - dtype, - trust_remote_code, - uds_path, - max_input_tokens, - ) + server.serve( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + dtype, + kv_cache_dtype, + trust_remote_code, + uds_path, + max_input_tokens, + ) @app.command() diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py deleted file mode 100644 index b03b7e26..00000000 --- a/backends/gaudi/server/text_generation_server/habana_quantization_env.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
- -import os -import habana_frameworks.torch as htorch - -quant_config = os.getenv("QUANT_CONFIG", "") -is_quantization_enabled = quant_config != "" - -if is_quantization_enabled: - os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true") - os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true") - os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false") - os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false") - os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av") - os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") - - -def patch_scoped_linear_all_reduce(model): - from deepspeed.module_inject.layers import LinearAllreduce - from optimum.habana.transformers.models.modeling_all_models import ( - ScopedLinearAllReduce, - ) - - for name, module in model.named_children(): - if type(module) is LinearAllreduce: - SL = ScopedLinearAllReduce(mod=module) - setattr(model, name, SL) - patch_scoped_linear_all_reduce(module) - - -def setup_quantization(model): - if is_quantization_enabled: - htorch.core.quantization._mark_params_as_const(model) - htorch.core.quantization._check_params_as_const(model) - htorch.core.hpu_initialize(model) - return model - - -def prepare_model_for_quantization(model): - if is_quantization_enabled: - if model.config.model_type in [ - "llama", - "falcon", - "qwen2", - "starcoder2", - "gemma", - ]: - patch_scoped_linear_all_reduce(model) - from neural_compressor.torch.quantization import FP8Config, convert - - config = FP8Config.from_json_file(quant_config) - model = convert(model, config) - return model diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py index 0000ca91..fd146728 100644 --- a/backends/gaudi/server/text_generation_server/layers/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/__init__.py @@ -12,6 +12,7 @@ from text_generation_server.layers.speculative import SpeculativeHead # Just to add the `load` methods. from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d +from text_generation_server.layers.fp8 import Fp8Linear from text_generation_server.layers.lora import ( LoraLinear, @@ -27,6 +28,7 @@ __all__ = [ "TensorParallelEmbedding", "SpeculativeHead", "LoraLinear", + "Fp8Linear", "TensorParallelMultiAdapterLinear", "TensorParallelAdapterRowLinear", "load_layer_norm", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 89a43d65..aa639832 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -10,18 +10,23 @@ from .hpu import ( SUPPORTS_WINDOWING, attention, paged_attention, + paged_attention_mla, + set_block_mapping, ) # KVCache needs `reshape_and_cache`, so ensure that it is defined already. 
-from .kv_cache import KVCache, get_kv_scales +from .kv_cache import KVCache, get_kv_scales, KVCompressCache __all__ = [ "attention", "get_kv_scales", "paged_attention", + "paged_attention_mla", + "set_block_mapping", "SUPPORTS_WINDOWING", "KVCache", + "KVCompressCache", "Seqlen", "HPUPagedAttentionMetadata", "trim_seqlen_metadata", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py index 9bd738fc..5e03cd44 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/common.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py @@ -90,6 +90,8 @@ class Seqlen: def _async_h2d_tensor_copy(source, device="hpu"): if source is None: return None + if source.device.type == "hpu": + return source assert source.device.type == "cpu", "Source tensor is not present in host memory!" target = torch.empty(source.shape, dtype=source.dtype, device=device) target.copy_(source, non_blocking=True) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 1d73dcb3..f12005d2 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -7,15 +7,67 @@ from vllm_hpu_extension.utils import Matmul from habana_frameworks.torch.hpex.kernels import FusedSDPA from vllm_hpu_extension.utils import ModuleFusedSDPA import os +from text_generation_server.models.globals import BLOCK_SIZE +import math SUPPORTS_WINDOWING = False -def fetch_from_cache(cache, blocks): - if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": - return cache[: blocks.size(0)] - else: - return cache.index_select(0, blocks) +class FP8Matmul(torch.nn.Module): + + def __init__(self, scale_other): + super().__init__() + self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu") + self.scale_other = scale_other + + def quant_input(self, x, scale): + return torch.ops.hpu.cast_to_fp8_v2( + x, scale, False, False, torch.float8_e4m3fn + )[0] + + def matmul_fp8( + self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None + ): + return torch.ops.hpu.fp8_gemm_v2( + A=x, + trans_A=False, + B=other, + trans_B=False, + D=None, + out_dtype=out_dtype, + A_scale_inv=scale_input_inv, + B_scale_inv=scale_other_inv, + bias=None, + accumulate=False, + ) + + def forward(self, input, other): + qinput = self.quant_input(input, self.scale_input) + qother = self.quant_input(other, self.scale_other) + output = self.matmul_fp8( + qinput, + qother, + out_dtype=torch.bfloat16, + scale_input_inv=1.0 / self.scale_input, + scale_other_inv=1.0 / self.scale_other, + ) + return output + + +class FetchFromCache(torch.nn.Module): + + def __init__(self, scale_inv): + super().__init__() + self.scale_inv = scale_inv + + def forward(self, cache, blocks): + if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true": + out = cache[: blocks.size(0)] + else: + out = cache.index_select(0, blocks) + if out.dtype == torch.float8_e4m3fn: + out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16) + return out def attention( @@ -55,6 +107,21 @@ def attention( return attn_output +def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size): + block_mapping = torch.nn.functional.one_hot( + hpu_attention_meta.block_groups, num_classes=batch_size + ) + dtype = 
hpu_attention_meta.block_usage.dtype + device = hpu_attention_meta.block_usage.device + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + hpu_attention_meta = hpu_attention_meta._replace( + attn_bias=attn_bias, block_mapping=block_mapping.to(dtype) + ) + return hpu_attention_meta + + def paged_attention( query: torch.Tensor, kv_cache: KVCache, @@ -67,6 +134,7 @@ def paged_attention( hpu_attention_meta: HPUPagedAttentionMetadata, ): batch_size, head_num, head_size = query.shape + fp8_kv = kv_cache.dtype == torch.float8_e4m3fn output = ops.flat_pa( query=query.view(batch_size, 1, head_num * head_size), key_cache=kv_cache.key, @@ -75,20 +143,59 @@ def paged_attention( block_mapping=hpu_attention_meta.block_mapping, block_bias=hpu_attention_meta.attn_bias, block_groups=hpu_attention_meta.block_groups, + block_size=BLOCK_SIZE, scale=softmax_scale, - matmul_qk_op=Matmul(), - matmul_av_op=Matmul(), + matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), + matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), batch2block_matmul_op=Matmul(), block2batch_matmul_op=Matmul(), - keys_fetch_func=fetch_from_cache, - values_fetch_func=fetch_from_cache, + keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu), + values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu), ) # Reshape the output tensor. return output.view(batch_size, head_num, head_size) +def paged_attention_mla( + query: torch.Tensor, + kv_cache: KVCache, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + seqlen: Seqlen, + *, + kv_scales: KVScales, + softcap: Optional[float] = None, + hpu_attention_meta: HPUPagedAttentionMetadata, + kv_lora_rank: int = 0, +): + batch_size, head_num, head_size = query.shape + fp8_kv = kv_cache.dtype == torch.float8_e4m3fn + output = ops.flat_pa_mla( + query=query, + key_cache=kv_cache.key, + value_cache=None, + block_list=hpu_attention_meta.block_list, + block_mapping=hpu_attention_meta.block_mapping, + block_bias=hpu_attention_meta.attn_bias, + block_groups=hpu_attention_meta.block_groups, + block_size=BLOCK_SIZE, + scale=softmax_scale, + matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(), + matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(), + batch2block_matmul_op=Matmul(), + block2batch_matmul_op=Matmul(), + keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu), + values_fetch_func=None, + kv_lora_rank=kv_lora_rank, + ) + # Reshape the output tensor. 
+ return output.view(batch_size, head_num, -1) + + __all__ = [ "SUPPORTS_WINDOWING", "attention", "paged_attention", + "paged_attention_mla", + "set_block_mapping", ] diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py index d238cdb9..723c1ec0 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py @@ -5,7 +5,6 @@ import torch from text_generation_server.models.globals import BLOCK_SIZE from text_generation_server.utils.weights import Weights -from vllm_hpu_extension import cache_ops @dataclass @@ -50,15 +49,17 @@ class KVCache: ): """Construct the key-value cache for a layer.""" ## TODO FP8 kv cache support + if dtype is torch.float8_e5m2: + raise ValueError("torch.float8_e5m2 is not supported in hpu. ") self.kv_cache = ( torch.zeros( - (num_blocks, BLOCK_SIZE, num_heads, head_size), + (num_blocks * BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device, ), torch.zeros( - (num_blocks, BLOCK_SIZE, num_heads, head_size), + (num_blocks * BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device, ), @@ -101,24 +102,89 @@ class KVCache: key_cache, value_cache, slots, - kv_scales.key_scale_cpu, - kv_scales.value_scale_cpu, + kv_scales.key_scale, + kv_scales.value_scale, ) +class KVCompressCache(KVCache): + """ + Key-value cache for attention layers. + """ + + kv_cache: torch.Tensor + + def __init__( + self, + *, + num_blocks: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + ## TODO FP8 kv cache support + if dtype is torch.float8_e5m2: + raise ValueError("torch.float8_e5m2 is not supported in hpu. 
") + + self.kv_cache = torch.zeros( + (num_blocks * BLOCK_SIZE, 1, head_size), + dtype=dtype, + device=device, + ) + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache.dtype + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + kv_scales: KVScales, + ): + """Store the key and value at the given slots.""" + ## TODO FP8 kv cache support + if self.kv_cache.dtype == torch.float8_e4m3fn: + key = torch.ops.hpu.cast_to_fp8_v2( + key, kv_scales.key_scale, False, False, torch.float8_e4m3fn + )[0] + self.kv_cache.index_copy_(0, slots, key) + + def paged_reshape_and_cache( key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ): - block_idx = slots // BLOCK_SIZE - block_offset = slots % BLOCK_SIZE - cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset) - cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset) + if key_cache.dtype == torch.float8_e4m3fn: + key = torch.ops.hpu.cast_to_fp8_v2( + key, k_scale, False, False, torch.float8_e4m3fn + )[0] + value = torch.ops.hpu.cast_to_fp8_v2( + value, v_scale, False, False, torch.float8_e4m3fn + )[0] + key_cache.index_copy_(0, slots, key) + value_cache.index_copy_(0, slots, value) def get_kv_scales(weights: Weights, prefix: str) -> KVScales: diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py new file mode 100644 index 00000000..507af706 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py @@ -0,0 +1,3 @@ +from .loader import CompressedTensorsLoader + +__all__ = ["CompressedTensorsLoader"] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py new file mode 100644 index 00000000..0dccf34a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py @@ -0,0 +1,169 @@ +from typing import Any, Dict, List, Union + +from compressed_tensors import QuantizationConfig, QuantizationStatus +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import ( + QuantizationScheme, + QuantizationType, + find_name_or_class_matches, +) +from loguru import logger +from pydantic import ValidationError +from torch import nn + +from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, + WeightsLoader, +) + +# compressed-tensors can match modules as quantization targets. However, +# they need to be objects rather than classes or class names. Since we +# need to match `Linear` targets, make an instance that can be re-used. 
+_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0) + + +class CompressedTensorsLoader(WeightsLoader): + """Loader for checkpoints stored in the compressed-tensors format.""" + + def __init__(self, config: Dict[str, Any]): + quantization_config_raw = config.get("quantization_config") + if quantization_config_raw is None: + # `compression_config` was renamed to `quantization_config`; support + # retained for backward compatibility. + quantization_config_raw = config.get("compression_config") + if quantization_config_raw is None: + raise ValueError( + "Checkpoint does not have compressed-tensors configuration" + ) + + try: + quantization_config = QuantizationConfig.model_validate( + quantization_config_raw + ) + except ValidationError as e: + raise ValueError("Cannot parse compressed-tensors configuration") from e + + if quantization_config.quantization_status not in ( + QuantizationStatus.COMPRESSED, + QuantizationStatus.FROZEN, + ): + raise ValueError( + f"Model quantization was not finished, status was: {quantization_config.quantization_status}" + ) + + self.ignore = ( + quantization_config.ignore if quantization_config.ignore is not None else [] + ) + self.loaders = self._get_target_loaders(quantization_config) + + for target, loader in self.loaders.items(): + log_once( + logger.info, + f"Using {loader} for compressed-tensors target '{target}'", + ) + + def get_weights(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights(weights, prefix) + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + loader = self._lookup_loader(prefix) + return loader.get_weights_col_packed(weights, prefix, block_sizes) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights_col(weights, prefixes, dim) + + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights(weights, prefixes, dim) + + def get_weights_row(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights_row(weights, prefix) + + def _get_target_loaders( + self, quantization_config: QuantizationConfig + ) -> Dict[str, WeightsLoader]: + """ + A compressed-tensors checkpoint can use different quantizations + for different targets. This method returns a dictionary with a + loader per target. + """ + + loaders: Dict[str, WeightsLoader] = {} + + format = quantization_config.format + + for group_name, group in quantization_config.config_groups.items(): + # The group configuration can be a string, but does that ever + # happen in a serialized quantization config? + assert isinstance(group, QuantizationScheme) + + loader = self._create_loader_for_group(format, group_name, group) + + # A quantized parameter group can have multiple targets, add the + # loader for all the targets. + for target in group.targets: + if target in loaders: + raise ValueError( + f"Target '{target} has multiple configured loaders'" + ) + loaders[target] = loader + + return loaders + + def _create_loader_for_group( + self, format: str, group_name: str, group: QuantizationScheme + ) -> WeightsLoader: + """ + Find and create a loader for the group with the given quantization + scheme. + """ + # NOTE: we ignore group.output_activations because we don't support + # output quantization yet. 
+ + input_activations = group.input_activations + weights = group.weights + if ( + format + in { + CompressionFormat.float_quantized.value, + CompressionFormat.naive_quantized.value, + } + and weights is not None + and weights.type == QuantizationType.FLOAT + and weights.num_bits == 8 + ): + # FP W8A8 or W8A16. + return W8ANFpLoader(input_activations=input_activations, weights=weights) + else: + raise ValueError( + f"Group '{group_name}' has unsupported compressed-tensors configurtion" + ) + + def _lookup_loader(self, prefix: str) -> WeightsLoader: + """ + Look up the loader to use for a given parameter name (prefix). + """ + + if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0: + return DefaultWeightsLoader(UnquantizedWeight) + + # We currently only handle linear layers, so unconditionally pass + # a `Linear` instance. + targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys()) + if len(targets) == 0: + raise ValueError( + f"Cannot find compressed-tensors target for prefix: {prefix}" + ) + return self.loaders[targets[0]] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py new file mode 100644 index 00000000..6eb00387 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py @@ -0,0 +1,253 @@ +from typing import List, Optional, Union + +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationType + +from text_generation_server.layers.fp8 import ( + Fp8Weight, + _load_scalar_or_matrix_scale, + requantize_with_max_scale, +) +from text_generation_server.utils.weights import Weights, WeightsLoader + + +class W8ANFpLoader(WeightsLoader): + """ + Loader for W8A8/W8A16 FP compressed-tensors parameters. + """ + + def __init__( + self, + *, + input_activations: Optional[QuantizationArgs], + weights: QuantizationArgs, + ): + assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8 + + # We ignore the `strategy` option which sets the scales to be + # per-tensor, per-channel or per-token. What scales are supported + # is dependent on the kernels used (e.g. cutlass can do tokenwise, + # Torch cannot, and FP8-Marlin does not quantize inputs at all). + # So, instead we try to use the best-possible configuration. 
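Editor's note: a minimal pure-Python sketch (the `derive_flags` helper is hypothetical, not part of the patch) of how the three flags assigned just below follow from the group's quantization arguments.

# Illustrative sketch (not part of the patch) of the flag derivation below.
def derive_flags(weights_dynamic: bool,
                 act_dynamic: bool | None,
                 act_num_bits: int | None):
    """act_* are None when the group has no input_activations entry."""
    load_weight_scale = not weights_dynamic            # static weight scale stored in the checkpoint
    load_input_scale = act_dynamic is not None and not act_dynamic
    force_w8a16 = act_num_bits == 16                   # 16-bit activations: weights stay FP8, inputs are not quantized
    return load_weight_scale, load_input_scale, force_w8a16

# FP8 checkpoint with dynamic activation quantization:
assert derive_flags(False, True, 8) == (True, False, False)
# W8A16 checkpoint: FP8 weights, activations kept in 16-bit:
assert derive_flags(False, True, 16) == (True, False, True)
# Weight-only FP8 group with no input_activations entry at all:
assert derive_flags(False, None, None) == (True, False, False)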
+ + self.load_weight_scale = not weights.dynamic + self.load_input_scale = ( + input_activations is not None and not input_activations.dynamic + ) + self.force_w8a16 = ( + input_activations is not None and input_activations.num_bits == 16 + ) + + def __str__(self) -> str: + def scale_to_str(scale): + return "static" if scale else "dynamic" + + quantization_type = f"W8A{16 if self.force_w8a16 else 8}" + + return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})" + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + weight_scale = None + if self.load_weight_scale: + weight_scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if weight_scale.numel() > 1: + weight_scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + if self.load_weight_scale: + weight_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") 
+ ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + + if self.load_weight_scale: + weight_scale = [ + weights.get_tensor(f"{p}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = [ + weights.get_tensor(f"{p}.input_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 0dc5cdaf..8de335ac 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -12,11 +12,151 @@ from text_generation_server.utils.weights import ( from vllm_hpu_extension.ops import scaled_fp8_quant from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 -import habana_frameworks.torch.utils.experimental as htexp -w8a8_block_fp8_matmul = None -per_token_group_quant_fp8 = None quant_dtype: torch.dtype = torch.float8_e4m3fn +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max +if is_hpu_gaudi2(): + FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max + + +def pad_weight(weight, block_size): + """Pads a matrix to make its dimensions multiples of block_size.""" + M, N = weight.shape[-2:] + block_size_m, block_size_n = block_size + pad_M = (block_size_m - M % block_size_m) % block_size_m + pad_N = 
(block_size_n - N % block_size_n) % block_size_n + + if pad_M == 0 and pad_N == 0: + return weight, M, N # No padding needed + padded_weight = torch.nn.functional.pad( + weight, (0, pad_N, 0, pad_M), mode="constant", value=0 + ) + return padded_weight, M, N # Return original dimensions for unpadding + + +def unpad_weight(weight, original_M, original_N, keep_first_dim=False): + """Removes padding from the matrix to restore its original shape.""" + if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N): + return weight + if keep_first_dim: + return weight[:, :original_M, :original_N] + else: + return weight[:original_M, :original_N] + + +def pad_block_fp8_weight_naive(weight, weight_scale, block_size): + + assert len(block_size) == 2 + + block_size_m, block_size_n = block_size + weight_scale_m, weight_scale_n = weight_scale.shape[-2:] + + weight, orig_M, orig_N = pad_weight(weight, block_size) + M, N = weight.shape[-2:] + + assert weight_scale_m == M // block_size_m + assert weight_scale_n == N // block_size_n + + return weight, orig_M, orig_N + + +def dynamic_quant(data, single_scale=False): + if single_scale: + scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX + else: + scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX + scale = scale.unsqueeze(-1) + data_fp8 = torch.ops.hpu.cast_to_fp8_v2( + data, 1.0 / scale, False, False, torch.float8_e4m3fn + )[0] + return data_fp8, scale.float() + + +def dequant_block_fp8_weight_naive( + weight, + weight_scale, + block_size, + dtype=torch.bfloat16, + original_M=None, + original_N=None, + do_unpad=False, +): + if weight_scale is None: + return weight + assert len(block_size) == 2 + + weight_shape_len = len(weight.shape) + + block_size_m, block_size_n = block_size + + # mul scale + if weight_shape_len == 2: + weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = False + elif weight_shape_len == 3: + fd, weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view( + fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n + ) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = True + else: + raise ValueError("Only support original weight shape is either 2 or 3") + + if do_unpad: + dequant_weight = unpad_weight( + dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim + ) + + return dequant_weight + + +def apply_block_fp8_linear_hpu_dynamic( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] 
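Editor's note: the call just below relies on `dynamic_quant` defined above. A small sketch, not part of the patch, of the per-row dynamic FP8 scaling it performs, written with plain CPU torch casts instead of `torch.ops.hpu.cast_to_fp8_v2`; the tensor shapes are arbitrary example values.

# Illustrative sketch (not part of the patch): per-row dynamic FP8 scaling.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max             # 448.0
x = torch.randn(4, 16, dtype=torch.bfloat16)                # (tokens, hidden)
scale = (x.abs().max(dim=-1).values + 1e-8) / FP8_MAX       # one scale per row
x_fp8 = (x / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)   # quantize
x_back = x_fp8.to(torch.bfloat16) * scale.unsqueeze(-1)     # dequantize
# x_back approximates x to within FP8 e4m3 precision (~6% relative error);
# the FP8 GEMM below consumes x_fp8 together with the per-row scales.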
+ + x_fp8, x_scale = dynamic_quant(input_2d) + + output = torch.ops.hpu.fp8_gemm_v2( + x_fp8, + False, + weight, + True, + None, + torch.bfloat16, + x_scale, + weight_scale, + None, + False, + ) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: @@ -42,7 +182,7 @@ def per_tensor_dequantize( ) -> torch.Tensor: device = tensor.device dtype = torch.bfloat16 - if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + if is_hpu_gaudi2(): # dequant on cpu to avoid nan on gaudi2 tensor = tensor.to("cpu") @@ -67,7 +207,7 @@ def requantize_with_max_scale( for idx, logical_width in enumerate(logical_widths): end = start + logical_width weight_dq = per_tensor_dequantize( - weight[start:end, :], weight_scale[idx], dtype + weight[start:end, :], weight_scale[start:end, :], dtype ) weight[start:end, :], max_w_scale_normalized = fp8_quantize( weight_dq, max_w_scale @@ -130,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader): ) # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -138,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) return Fp8Weight( weight=w, @@ -176,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) + scale = scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -190,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader): to_dtype=False, ) input_scale = input_scale.reshape(-1).max() - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) return Fp8Weight( weight=w, @@ -240,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader): ] scale = torch.cat(scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) + input_scale = [ _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) for p, shape in zip(prefixes, shapes) @@ -252,9 +399,66 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) + return Fp8Weight( + weight=w, + weight_scale=scale, + input_scale=input_scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + 
weights.get_tensor(f"{p}.weight_scale_inv", to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + + scale = [ + weights.get_tensor(f"{p}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + ] + scale = torch.cat(scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] w, scale = requantize_with_max_scale( - w, scale.to(weights.device), logical_widths, weights.dtype + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) + + input_scale = [ + weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1) + for p in prefixes + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None ) return Fp8Weight( @@ -285,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader): weight_block_size=self.weight_block_size, ) - scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -294,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) + return Fp8Weight( weight=w, weight_scale=scale, @@ -389,6 +598,22 @@ class Fp8Linear(torch.nn.Module): scale_upper_bound = kwargs.get("scale_upper_bound", None) weight_block_size = kwargs.get("weight_block_size", None) + if weight_block_size is not None: + weight, orig_M, orig_N = pad_block_fp8_weight_naive( + weight, scale, weight_block_size + ) + weight, scale = dynamic_quant( + dequant_block_fp8_weight_naive( + weight, + scale, + weight_block_size, + original_M=orig_M, + original_N=orig_N, + do_unpad=True, + ) + ) + scale = scale.squeeze(-1) + return cls( qweight=weight, scale=scale, @@ -399,60 +624,32 @@ class Fp8Linear(torch.nn.Module): weight_block_size=weight_block_size, ) - @classmethod - def get_shared_device_identity(cls, device): - # Input scaling factors are no longer optional in _scaled_mm starting - # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale - if device not in cls._device_identity_cache: - cls._device_identity_cache[device] = torch.ones(1, device=device) - return cls._device_identity_cache[device] - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.weight_block_size is not None: - # https://arxiv.org/pdf/2412.19437 - # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and - # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we - # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output - # channels). 
- qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) - output = w8a8_block_fp8_matmul( - qinput, - self.qweight, - scale, - self.scale, - self.weight_block_size, - output_dtype=input.dtype, + if self.weight_block_size is not None or self.input_scale is None: + return apply_block_fp8_linear_hpu_dynamic( + input, self.qweight, self.scale, self.input_scale, self.bias ) - if self.bias is not None: - output = output + self.bias - return output.to(dtype=input.dtype) - - qinput, scale = fp8_quantize( - input, - self.input_scale, - scale_upper_bound=self.scale_upper_bound, - scalar=True, - ) - - output = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, + x_fp8 = torch.ops.hpu.cast_to_fp8_v2( + input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn + )[0] + return torch.ops.hpu.fp8_gemm_v2( + A=x_fp8, + trans_A=False, + B=self.qweight, + trans_B=True, + D=None, + out_dtype=input.dtype, + A_scale_inv=self.input_scale, + B_scale_inv=self.scale, bias=self.bias, + accumulate=False, ) - if isinstance(output, tuple) and len(output) == 2: - output = output[0] - - return output - def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) - return scale.reshape(-1) + return scale.reshape(-1).expand(shape[0]) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 90b8f692..96b120b2 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -4,7 +4,12 @@ from typing import List, Optional, Union import torch from loguru import logger from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +from text_generation_server.utils.weights import ( + Weight, + Weights, + WeightsLoader, + DefaultWeightsLoader, +) from .hpu import QuantLinear @@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader): quant_method: str, quantize: str, sym: bool, + modules_to_not_convert: List[str], ): self.bits = bits self.desc_act = desc_act @@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader): self.quant_method = quant_method self.quantize = quantize self.sym = sym + self.modules_to_not_convert = modules_to_not_convert + + def is_layer_skipped_quantization( + self, prefix: str, modules_to_not_convert: List[str] + ): + return any(module_name in prefix for module_name in modules_to_not_convert) def get_weights(self, weights: Weights, prefix: str): self._get_gptq_params(weights) @@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader): log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights(weights, prefix) + try: qweight = weights.get_tensor(f"{prefix}.qweight") except RuntimeError: @@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_col_packed( + weights, prefix, block_sizes + ) try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ 
-196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader): ) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim) try: qweight = torch.cat( [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -255,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader): use_exllama=use_exllama, ) + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim) + try: + qweight = torch.cat( + [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1) + + self._get_gptq_params(weights) + + qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1) + + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + def get_weights_row(self, weights: Weights, prefix: str): self._get_gptq_params(weights) @@ -263,6 +341,9 @@ class GPTQWeightsLoader(WeightsLoader): if self.bits != 4: use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_row(weights, prefix) + if self.desc_act: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py index 84878791..4bbb6c1f 100644 --- a/backends/gaudi/server/text_generation_server/layers/layernorm.py +++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py @@ -53,15 +53,10 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - from vllm_hpu_extension.kernels import rms_norm - - orig_shape = hidden_states.shape if residual is not None: - residual += hidden_states.view(residual.shape) - else: - residual = hidden_states - # Note: HPUFusedRMSNorm requires 3D tensors as inputs - if len(orig_shape) == 2: - residual = residual.unsqueeze(0) - x = rms_norm().apply(residual, self.weight, self.variance_epsilon) - return x.view(orig_shape), residual.view(orig_shape) + hidden_states += residual + residual = hidden_states + hidden_states = hidden_states.to(torch.float32) + 
variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(self.weight.dtype), residual diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py index 071b2abe..5365f24f 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.nn as nn +import os from text_generation_server.utils.weights import Weights from text_generation_server.layers.fp8 import ( @@ -9,12 +10,11 @@ from text_generation_server.layers.fp8 import ( fp8_quantize, quant_dtype, normalize_e4m3fn_to_native_float8, + dynamic_quant, + dequant_block_fp8_weight_naive, ) - -try: - from .unquantized import fused_moe -except Exception: - fused_moe = None +from text_generation_server.layers.moe.fused_moe import select_experts +import habana_frameworks.torch as htorch class FP8SparseMoELayer(nn.Module): @@ -47,6 +47,16 @@ class FP8SparseMoELayer(nn.Module): self.weight_block_size = weights.weights_loader.weight_block_size self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.world_size = weights.process_group.size() + self.rank = weights.process_group.rank() + self.ep_rank = self.rank + self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true" + + if self.use_ep: + n_experts = (n_experts + self.world_size - 1) // self.world_size + self.ep_offset = self.ep_rank * n_experts + else: + self.ep_offset = 0 ( self.gate_up_proj, @@ -58,6 +68,8 @@ class FP8SparseMoELayer(nn.Module): gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, weights=weights, + use_ep=self.use_ep, + ep_offset=self.ep_offset, ) self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = ( @@ -66,29 +78,89 @@ class FP8SparseMoELayer(nn.Module): n_experts=n_experts, name=down_proj_name, weights=weights, + use_ep=self.use_ep, + ep_offset=self.ep_offset, ) ) + if self.weight_block_size is not None: + self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.weight_block_size, + ) + ) + self.down_proj, self.down_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.down_proj, self.down_proj_weight_scale, self.weight_block_size + ) + ) + self.gate_up_proj_weight_scale, self.down_proj_weight_scale = ( + self.gate_up_proj_weight_scale.squeeze(-1), + self.down_proj_weight_scale.squeeze(-1), + ) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=gating_output, use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, + top_k=self.topk, + renormalize=self.renormalize, topk_group=self.topk_group, + num_expert_group=self.n_expert_group, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, - use_fp8_w8a8=True, - w1_scale=self.gate_up_proj_weight_scale, - w2_scale=self.down_proj_weight_scale, - a1_scale=self.gate_up_proj_input_scale, - a2_scale=self.down_proj_input_scale, ) + total_num_experts = 
gating_output.size(-1) + x_fp8, x_scale = dynamic_quant(x, single_scale=True) + + if self.use_ep: + moe_n_slice = 1 + n_expert_slice = ( + total_num_experts + self.world_size - 1 + ) // self.world_size + else: + moe_n_slice = 1 + n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice + for i in range(moe_n_slice): + min_expert = i * n_expert_slice + max_expert = min((i + 1) * n_expert_slice, total_num_experts) + w13_list_slice = [ + self.gate_up_proj[j, ...] for j in range(min_expert, max_expert) + ] + w2_list_slice = [ + self.down_proj[j, ...] for j in range(min_expert, max_expert) + ] + w13_weight_scale = [ + self.gate_up_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + w2_weight_scale = [ + self.down_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + + current_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=x_fp8, + expert_routing_table=topk_ids.to(torch.int64), + router_weights=topk_weights.to(x.dtype), + w12=w13_list_slice, + w3=w2_list_slice, + d_scale_hidden_states=x_scale, + d_scale_w12=w13_weight_scale, + d_scale_w3=w2_weight_scale, + permuted_weights=True, + activation="silu", + experts_min=min_expert + self.ep_offset, + experts_max=max_expert + self.ep_offset - 1, + ) + htorch.core.mark_step() + if i == 0: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + return final_hidden_states def _load_expert_weights( @@ -98,13 +170,14 @@ def _load_expert_weights( n_experts: int, name: str, weights: Weights, + ep_offset: int = 0, ) -> torch.Tensor: all_weight = None all_weight_scales = None max_input_scale = None for i in range(n_experts): - weight = get_weight_fn(prefix, i, name, weights) + weight = get_weight_fn(prefix, i + ep_offset, name, weights) assert isinstance(weight, Fp8Weight) @@ -147,14 +220,26 @@ def _load_expert_multi_weights_col( gate_proj_name: str, up_proj_name: str, weights: Weights, + use_ep: bool = False, + ep_offset: int = 0, ) -> torch.Tensor: - def get_weight_fn(prefix, i, name, weights): + def get_weight_fn_sharded(prefix, i, name, weights): return weights.get_multi_weights_col( [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 ) + def get_weight_fn(prefix, i, name, weights): + return weights.get_multi_weights( + [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0 + ) + return _load_expert_weights( - get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights + get_weight_fn if use_ep else get_weight_fn_sharded, + prefix=prefix, + n_experts=n_experts, + name=None, + weights=weights, + ep_offset=ep_offset if use_ep else 0, ) @@ -164,10 +249,20 @@ def _load_expert_weights_row( n_experts: int, name: str, weights: Weights, + use_ep: bool = False, + ep_offset: int = 0, ) -> torch.Tensor: - def get_weight_fn(prefix, i, name, weights): + def get_weight_fn_sharded(prefix, i, name, weights): return weights.get_weights_row(f"{prefix}.{i}.{name}") + def get_weight_fn(prefix, i, name, weights): + return weights.get_weights(f"{prefix}.{i}.{name}") + return _load_expert_weights( - get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights + get_weight_fn if use_ep else get_weight_fn_sharded, + prefix=prefix, + n_experts=n_experts, + name=name, + weights=weights, + ep_offset=ep_offset if use_ep else 0, ) diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py index 
e26ff877..1987f0ed 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Tuple, Optional import torch @@ -25,12 +25,36 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - scores = torch.softmax(gating_output, dim=-1) + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + gating_output = gating_output.float() + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.float() + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ 1 ] # [n, top_k_group] @@ -41,13 +65,19 @@ def grouped_topk( .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) .reshape(num_token, -1) ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def fused_topk( @@ -63,3 +93,39 @@ def fused_topk( if renormalize: topk_weights /= topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +): + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + 
e_score_correction_bias=e_score_correction_bias, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py index ec158398..58709ec3 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn from text_generation_server.utils.weights import UnquantizedWeight, Weights -from vllm_hpu_extension.ops import DynamicFusedMOE +from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp +import habana_frameworks.torch as htorch +import torch.nn.functional as F class UnquantizedSparseMoELayer(nn.Module): @@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module): weights=weights, ) - self.hpu_fused_moe = DynamicFusedMOE(n_experts) + self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1) for i in range(n_experts): - self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) - self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i]) + self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i]) + self.MoeOp.w2_list[i].set_weight(self.down_proj[i]) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return self.hpu_fused_moe(x, gating_output, self.topk) + htorch.core.mark_step() + routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32) + routing_weights, selected_experts = torch.topk( + routing_weights, self.topk, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(x.dtype) + + final_hidden_states = self.MoeOp( + hidden_states=x, + expert_routing_table=selected_experts, + router_weights=routing_weights, + permuted_weights=True, + activation="silu", + ) + + return final_hidden_states.view(-1, x.shape[1]) def _load_expert_multi_weights_col( diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 6a83d6a5..d381d4c6 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -36,9 +36,7 @@ class PositionRotaryEmbedding(nn.Module): self._sin_k_cached = None self.scaling_factor = scaling_factor self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, inv_freq.device, max_position_embeddings - ) + self.max_position_embeddings = max_position_embeddings def forward( self, @@ -270,7 +268,9 @@ class PositionRotaryEmbedding(nn.Module): self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin(self, position_ids: torch.Tensor): - + self._update_cos_sin_cache( + torch.float32, position_ids.device, seqlen=self.max_position_embeddings + ) cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) @@ -298,9 +298,6 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, short_inv_freq.device, max_position_embeddings - ) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -354,9 +351,6 @@ class 
Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding): self._cos_k_cached = None self._sin_k_cached = None self.dynamic_args = None - self._update_cos_sin_cache( - torch.float32, short_inv_freq.device, max_position_embeddings - ) def _update_cos_sin_cache(self, dtype, device, seqlen): if ( @@ -470,9 +464,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__( - inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor - ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -487,6 +478,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): / get_mscale(self.scaling_factor, mscale_all_dim) * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation + super().__init__(inv_freq, scaling_factor, max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -600,6 +592,9 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding): position_ids: torch.Tensor, ): slen = position_ids.shape[0] + self._update_cos_sin_cache( + torch.float32, position_ids.device, seqlen=self.max_position_embeddings + ) cos = self._cos_cached[position_ids].gather(1, self._sections[:slen]) sin = self._sin_cached[position_ids].gather(1, self._sections[:slen]) diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 778b14a1..3a91e94c 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -5,7 +5,6 @@ import os from loguru import logger from transformers.configuration_utils import PretrainedConfig -from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path @@ -16,9 +15,6 @@ import enum from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.bloom import BLOOM -from text_generation_server.models.starcoder import StarCoder from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( PhiMoEConfig, ) @@ -32,7 +28,6 @@ from text_generation_server.utils.adapter import ( from text_generation_server.adapters.lora import LoraWeights from text_generation_server.utils.log import log_master -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi __all__ = [ "Model", @@ -40,13 +35,10 @@ __all__ = [ "Seq2SeqLM", "get_model_with_lora_adapters", ] -from text_generation_server.models.globals import ATTENTION -FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." 
+VLM_BATCH_TYPES = set() -FLASH_ATTENTION = False -if ATTENTION == "paged": - FLASH_ATTENTION = True +FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM @@ -63,6 +55,9 @@ try: from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_llama4_modeling import ( + Llama4ForConditionalGeneration, + ) from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( FlashCohereForCausalLM, ) @@ -83,9 +78,6 @@ try: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, ) - from text_generation_server.models.pali_gemma import ( - PaliGemmaBatch, - ) from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( PaliGemmaForConditionalGeneration, ) @@ -109,6 +101,12 @@ try: from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2ForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_qwen3_modeling import ( + Qwen3ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import ( + Qwen3MoeForCausalLM, + ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) @@ -140,10 +138,23 @@ except ImportError as e: log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False + VLM_BATCH_TYPES = set() if FLASH_ATTENTION: __all__.append(FlashCausalLM) + from text_generation_server.models.flash_vlm_causal_lm import ( + FlashVlmCausalLMBatch, + ) + + VLM_BATCH_TYPES = { + FlashVlmCausalLMBatch, + FlashMllamaCausalLMBatch, + } + + +__all__.append(VLM_BATCH_TYPES) + class ModelType(enum.Enum): DEEPSEEK_V2 = { @@ -179,6 +190,11 @@ class ModelType(enum.Enum): "name": "Llama", "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", } + LLAMA4 = { + "type": "llama4", + "name": "Llama4", + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", + } PHI3 = { "type": "phi3", "name": "Phi 3", @@ -274,6 +290,16 @@ class ModelType(enum.Enum): "name": "Qwen 2.5 VL", "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", } + QWEN3 = { + "type": "qwen3", + "name": "Qwen 3", + "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f", + } + QWEN3_MOE = { + "type": "qwen3_moe", + "name": "Qwen 3 Moe", + "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f", + } GALACTICA = { "type": "galactica", "name": "Galactica", @@ -324,6 +350,7 @@ def get_model( quantize: Optional[str], speculate: Optional[int], dtype: Optional[torch.dtype], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, ) -> Model: @@ -449,7 +476,12 @@ def get_model( model_type = config_dict["model_type"] - kv_cache_dtype = dtype + if kv_cache_dtype == "fp8_e4m3fn": + kv_cache_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + kv_cache_dtype = torch.float8_e5m2 + else: + kv_cache_dtype = dtype if FLASH_ATTENTION: if model_type == DEEPSEEK_V2: @@ -589,6 +621,20 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) + elif model_type == LLAMA4: + print(f"Llama4 model detected: {model_id}") + return FlashVlmCausalLM( + model_id=model_id, + model_class=Llama4ForConditionalGeneration, + 
revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + support_chunking=False, + ) elif model_type == BAICHUAN: return FlashCausalLM( model_id=model_id, @@ -737,6 +783,8 @@ def get_model( kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + # TODO: Fix bug in rust image_text_replacement implementation + support_chunking=False, ) elif model_type == QWEN2_5_VL: return FlashVlmCausalLM( @@ -752,6 +800,32 @@ def get_model( lora_adapter_ids=lora_adapter_ids, config_class=Qwen2_5_VLConfig, processor_class=Qwen2_5_VLProcessor, + # TODO: Fix bug in rust image_text_replacement implementation + support_chunking=False, + ) + elif model_type == QWEN3: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + elif model_type == QWEN3_MOE: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen3MoeForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif model_type == MLLAMA: return FlashMllamaCausalLM( @@ -765,6 +839,7 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + support_chunking=False, ) elif model_type == IDEFICS2: return FlashVlmCausalLM( @@ -809,7 +884,6 @@ def get_model( default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, ) elif model_type == LLAVA_NEXT: return FlashVlmCausalLM( @@ -823,60 +897,6 @@ def get_model( trust_remote_code=trust_remote_code, ) - from text_generation_server.models.vlm_causal_lm import VlmCausalLM - from text_generation_server.models.custom_modeling.mllama import ( - MllamaForConditionalGeneration, - ) - from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, - ) - - adapt_transformers_to_gaudi() - if SDP_ON_BF16 == 1: - torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) - if model_type == "gpt_bigcode": - return StarCoder(model_id=model_id, revision=revision, dtype=dtype) - if model_type == "bloom": - return BLOOM( - model_id=model_id, - revision=revision, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == "llava_next": - return VlmCausalLM( - model_class=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == "mllama": - return VlmCausalLM( - model_class=MllamaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=None, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - raise ValueError(f"Unsupported model type {model_type}") @@ -890,6 +910,7 @@ def get_model_with_lora_adapters( quantize: Optional[str], speculate: Optional[int], dtype: 
Optional[torch.dtype], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, adapter_to_index: Dict[str, int], @@ -903,6 +924,7 @@ def get_model_with_lora_adapters( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, max_input_tokens, ) diff --git a/backends/gaudi/server/text_generation_server/models/bloom.py b/backends/gaudi/server/text_generation_server/models/bloom.py deleted file mode 100644 index 6fe64374..00000000 --- a/backends/gaudi/server/text_generation_server/models/bloom.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. - -import torch - -from typing import Optional, Type - -from transformers import PreTrainedTokenizerBase - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 - - -class BloomCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb( - pb=pb, - tokenizer=tokenizer, - dtype=dtype, - device=device, - ) - batch.keys_head_dim_last = False - return batch - - -class BLOOM(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(BLOOM, self).__init__( - model_id=model_id, - revision=revision, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py deleted file mode 100644 index 76c09264..00000000 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ /dev/null @@ -1,1446 +0,0 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
- -import bisect -from dataclasses import dataclass -from functools import wraps -import itertools -import json -import math -import os -import tempfile -import time -import copy -from typing import Dict, List, Optional, Tuple, Type - -import torch -import torch._dynamo -from loguru import logger -from opentelemetry import trace - -import text_generation_server.habana_quantization_env as hq_env -from text_generation_server.utils import weight_files -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from text_generation_server.utils.chunks import concat_text_chunks -from optimum.habana.checkpoint_utils import model_on_meta -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - PreTrainedTokenizerBase, - AutoConfig, -) - -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - StoppingCriteria, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -from optimum.habana.utils import get_hpu_memory_stats -from text_generation_server.utils.debug import dbg_trace -from text_generation_server.utils.speculate import get_speculate - -tracer = trace.get_tracer(__name__) -MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) -BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2)) -SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2)) -MAX_BATCH_SIZE = ( - int(os.environ.get("MAX_BATCH_SIZE")) - if os.environ.get("MAX_BATCH_SIZE") is not None - else None -) - - -def torch_compile_for_eager(func): - if LAZY_MODE == 1: - return func - return torch.compile( - func, backend="hpu_backend", options={"keep_input_mutations": True} - ) - - -def round_up_seq(number, k, base): - exponent = math.ceil(math.log(number / k, base)) - return int(k * (base**exponent)) - - -def iterate_powers_of_base(max_value, start, base): - current = start - result = [] - assert ( - max_value >= start - ), f"max_value {max_value} must be greater than start {start}" - while current < max_value: - result.append(current) - current *= base - return result - - -def round_up_batch(number): - return BATCH_SIZE_EXPONENT_BASE ** ( - math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE)) - ) - - -def to_tensor_indices(indices, device): - return torch.tensor(indices, dtype=torch.long, device=device) - - -def calculate_chunks(offset): - result = [] - while offset != 0: - sign = 1 if offset > 0 else -1 - best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1] - result.append(best_chunk) - offset = offset - best_chunk - return result - - -def biggest_single_chunk(offset): - if offset != 0: - idx = bisect.bisect(CHUNK_SIZES, abs(offset)) - return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) - else: - return 0 - - -@torch_compile_for_eager -def grouped_pad(tensor_groups, dims, values): - grouped_result = [] - for tensors, dim, value in zip(tensor_groups, dims, values): - padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim 
is not None else 0 - if padding > 0: - assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}" - pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) - result = [ - torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors - ] - else: - result = [t for t in tensors] - grouped_result.append(result) - htorch.core.mark_step() - return grouped_result - - -@torch_compile_for_eager -def roll(tensor, chunk, dim, merge_graphs): - if dim is None: - return tensor - tensor = torch.roll(tensor, chunk, dim) - if not merge_graphs: - htorch.core.mark_step() - return tensor - - -def grouped_roll(tensor_groups, chunk, dims, merge_graphs): - tensor_groups = [ - [roll(t, chunk, dim, merge_graphs) for t in tensors] - for tensors, dim in zip(tensor_groups, dims) - ] - if merge_graphs: - htorch.core.mark_step() - return tensor_groups - - -@torch_compile_for_eager -def grouped_shift(tensor_groups, dims, offset, merge_graphs): - chunks = calculate_chunks(offset) - for c in chunks: - tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs) - return tensor_groups - - -def move(dst_tensors, dst_indices, src_tensors): - bs_dim = 0 - num_indices = dst_indices.size(0) - for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)): - if src_t.size(bs_dim) != num_indices: - src_t = torch.narrow(src_t, bs_dim, 0, num_indices) - dst_t.index_copy_(bs_dim, dst_indices, src_t) - htorch.core.mark_step() - - -def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups): - for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups): - move(dst_tensors, dst_indices, src_tensors) - - -@torch_compile_for_eager -def extend_tensor(tensor, padding, dim): - result = torch.cat([tensor, padding], dim=dim) - htorch.core.mark_step() - return result - - -@torch_compile_for_eager -def extend_batch(tensors, target_bs, dim): - diff = target_bs - tensors[0].size(dim) - # TODO: add support for shrinking bs - if diff <= 0: - return tensors - shape = list(tensors[0].shape) - shape[dim] = diff - padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) - tensors = [extend_tensor(t, padding, dim) for t in tensors] - return tensors - - -def grouped_extend_batch(tensor_groups, target_bs, bs_dims): - tensor_groups = [ - extend_batch(tensors, target_bs, dim) - for tensors, dim in zip(tensor_groups, bs_dims) - ] - return tensor_groups - - -@torch_compile_for_eager -def merge(tensor_group): - tensor_group = [torch.stack(tensor_group)] - htorch.core.mark_step() - return tensor_group - - -@torch_compile_for_eager -def split(tensor_group, clone_data): - tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)] - if clone_data: - tensor_group = [t.clone() for t in tensor_group] - htorch.core.mark_step() - return tensor_group - - -def remove_kv_cache_from_output(module): - orig_fwd = module.forward - - @wraps(orig_fwd) - def forward(*args, **kwargs): - if kwargs["past_key_values"] is not None: - kwargs["return_dict"] = False - output = orig_fwd(*args, **kwargs) - first_value, second_value, *_ = output - if first_value.nelement() < 2: - return second_value - else: - return first_value - else: - kwargs["return_dict"] = True - return orig_fwd(*args, **kwargs) - - module.forward = forward - return module - - -@dataclass -class CausalLMRequest: - idx: int - data: generate_pb2.Request - input_length: int - prefix_offset: int - read_offset: int - stopping_criteria: StoppingCriteria - - all_input_ids: torch.Tensor - - @classmethod - def from_pb( - cls, idx: int, data: 
generate_pb2.Request, tokenizer: PreTrainedTokenizerBase - ): - return cls( - idx=idx, - data=data, - input_length=None, - prefix_offset=None, - read_offset=None, - stopping_criteria=StoppingCriteria.from_pb( - data.stopping_parameters, tokenizer - ), - all_input_ids=None, - ) - - def update_idx(self, new_idx): - prev = self.idx - self.idx = new_idx - return (new_idx, prev) - - -@dataclass -class CausalLMBatch(Batch): - batch_id: int - requests: List[CausalLMRequest] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - past_key_values: Optional[List[Tuple]] - merged_kv_cache: bool - - # Lengths of all generations present in the batch - input_length: int - - # Generation helpers - next_token_chooser: HeterogeneousNextTokenChooser - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - input_length: int - - # Past metadata - logits = None - past = None - - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.data.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - def detach_kv_cache(self): - past_keys = [past[0] for past in self.past_key_values] - past_values = [past[1] for past in self.past_key_values] - del self.past_key_values - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - # TODO: Add support for models that don't store kv_cache in a list - self.past_key_values = list(zip(past_keys, past_values)) - - def merge_kv_cache_if_needed(self, target_bs, offset): - pad_needed = self.seq_length < MAX_TOTAL_TOKENS - shift_needed = offset != 0 - expand_needed = target_bs > self.batch_size - # Very simple heuristic to determine whether we should merge tensors - # this needs tuning for other models/scenarios - small_bs = len(self.past_key_values) > self.batch_size - if ( - not self.merged_kv_cache - and small_bs - and (pad_needed or shift_needed or expand_needed) - ): - past_keys, past_values = self.detach_kv_cache() - past_keys = merge(past_keys) - past_values = merge(past_values) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = True - - def split_kv_cache_if_needed(self, clone_data): - if self.merged_kv_cache: - past_keys, past_values = self.detach_kv_cache() - past_keys = split(past_keys, clone_data) - past_values = split(past_values, clone_data) - self.attach_kv_cache(past_keys, past_values) - self.merged_kv_cache = False - - def get_tensor_groups(self): - past_keys, past_values = self.detach_kv_cache() - seq_dim = -1 - key_dim = -2 if self.keys_head_dim_last else -1 - value_dim = -2 - tensors = [ - [self.input_ids], - [self.attention_mask], - [self.position_ids], - past_keys, - past_values, - ] - # We don't need to align position_ids - seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] - bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) - return tensors, seq_dims, bs_dims - - def set_tensor_groups(self, tensors): - self.input_ids = tensors.pop(0)[0] - self.attention_mask = tensors.pop(0)[0] - self.position_ids = tensors.pop(0)[0] - past_keys = tensors.pop(0) - past_values = tensors.pop(0) - self.attach_kv_cache(past_keys, past_values) - - def realign(self, target_bs, offset, pad_token_id): - tensors, seq_dims, _ = self.get_tensor_groups() - tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0]) - tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache) - self.set_tensor_groups(tensors) - 
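# --- Editor's note (not part of the patch): a minimal, hedged sketch of the
# --- pad-then-shift realignment idea used by grouped_pad/grouped_shift and
# --- CausalLMBatch.realign above, assuming plain CPU torch tensors and none of
# --- the HPU-specific calls (mark_step, graph merging) for clarity. Names such
# --- as pad_and_shift/target_len are illustrative, not from the patch.
import torch

def pad_and_shift(seq: torch.Tensor, target_len: int, offset: int, pad_value: int) -> torch.Tensor:
    # Right-pad along the sequence (last) dimension up to the bucketed length...
    missing = target_len - seq.size(-1)
    if missing > 0:
        seq = torch.nn.functional.pad(seq, (0, missing), value=pad_value)
    # ...then roll so the shorter sequence lines up with the longest one in the batch.
    if offset != 0:
        seq = torch.roll(seq, offset, dims=-1)
    return seq

short = torch.tensor([[11, 12, 13]])
print(pad_and_shift(short, target_len=5, offset=2, pad_value=0))
# tensor([[ 0,  0, 11, 12, 13]])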
- def expand_bs(self, target_bs): - tensors, _, bs_dims = self.get_tensor_groups() - tensors = grouped_extend_batch(tensors, target_bs, bs_dims) - self.set_tensor_groups(tensors) - - def used_indices(self): - return [req.idx for req in self.requests] - - def update_indices(self, new_indices): - for req, new_idx in zip(self.requests, new_indices): - req.idx = new_idx - return self.used_indices() - - def free_indices_generator(self): - used = set(req.idx for req in self.requests) - return (i for i in range(self.batch_size) if i not in used) - - def move_data(self, src_batches): - dst_tensors, _, dst_dims = self.get_tensor_groups() - free_indices_gen = self.free_indices_generator() - for src_b in src_batches: - dst_indices = to_tensor_indices( - src_b.update_indices(free_indices_gen), self.input_ids.device - ) - src_tensors, _, src_dims = src_b.get_tensor_groups() - grouped_move(dst_tensors, dst_indices, src_tensors) - self.set_tensor_groups(dst_tensors) - - @classmethod - def recombine( - cls, batches: List["CausalLMBatch"], pad_token_id: int - ) -> "CausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV cache not allocated! Cannot recombine before prefill!") - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - new_bs = round_up_batch(total_requests) - - batch_id = batches[0].batch_id - device = batches[0].input_ids.device - - input_lengths = [b.input_length for b in batches] - max_input_length = max(input_lengths) - offsets = [max_input_length - b.input_length for b in batches] - - cur_padding = [b.right_padding for b in batches] - # For prefill there is a space allocated only for first token - # Need to add padding to the max total tokens before first decode - - moves_needed = [ - total_requests - len(b) if b.batch_size == new_bs else total_requests - for b in batches - ] - dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] - reshape = batches[dst_batch_idx].batch_size < new_bs - - # TODO: Add support for changing max seq len, i.e. 
due to output length bucketing - # FIXME: max_seq_len for non optimized code - if len(batches) > 1: - scenario = "CONCAT" - elif reshape: - scenario = "RESHAPE" - elif cur_padding[dst_batch_idx] <= 0: - scenario = "SHIFT" - offsets = [ - biggest_single_chunk(b.max_input_length - max_input_length) - for b in batches - ] - max_input_length = max_input_length + offsets[dst_batch_idx] - else: - # Nothing to do - return batches[0] - - dbg_trace( - scenario, - f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}" - f" offsets:{offsets}" - f" input_lengths:{input_lengths}" - f" cur_padding:{cur_padding}" - f" dst_batch:{dst_batch_idx}", - ) - - grouped_requests = [[req for req in batch.requests] for batch in batches] - flat_requests = list(itertools.chain(*grouped_requests)) - - for i in range(len(batches)): - target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size - batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) - batches[i].realign(target_bs, offsets[i], pad_token_id) - batches[i].split_kv_cache_if_needed(i == dst_batch_idx) - batches[dst_batch_idx].expand_bs(new_bs) - batches[dst_batch_idx].move_data( - [batches[i] for i in range(len(batches)) if i != dst_batch_idx] - ) - - top_n_tokens = [r.data.top_n_tokens for r in flat_requests] - top_n_tokens.extend([-1] * (new_bs - total_requests)) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - parameters = [r.data.parameters for r in flat_requests] - # append the dummy parameters for dummy requests - batch_size = batches[dst_batch_idx].batch_size - parameters = pad_next_token_chooser_parameters(parameters, batch_size) - - # update past grammar states - fsm_grammar_states = [0] * batch_size - for batch in batches: - for i, req in enumerate(batch.requests): - fsm_grammar_states[req.idx] = ( - batch.next_token_chooser.fsm_grammar_states[i] - ) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[dst_batch_idx].next_token_chooser.dtype, - batches[dst_batch_idx].next_token_chooser.device, - batches[dst_batch_idx].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - input_ids = batches[dst_batch_idx].input_ids - attention_mask = batches[dst_batch_idx].attention_mask - position_ids = batches[dst_batch_idx].position_ids - past_key_values = batches[dst_batch_idx].past_key_values - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=flat_requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") - requests = [ - CausalLMRequest.from_pb(idx, req, tokenizer) - for idx, req in enumerate(pb.requests) - ] - inputs = [] - top_n_tokens = [] - - # Parse batch - max_truncation = 0 - for i, r in enumerate(pb.requests): - inputs.append(concat_text_chunks(r.input_chunks.chunks)) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - - max_input_length = max_truncation - if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF: - 
max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up_batch(len(requests)) - missing_inputs = new_bs - len(inputs) - dummy_inputs = ["?"] * missing_inputs - parameters = [r.parameters for r in pb.requests] - # append the dummy parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - - tokenized_inputs = tokenizer( - inputs + dummy_inputs, - return_tensors="pt", - padding="longest", - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ) - - input_len = tokenized_inputs["input_ids"].shape[1] - # Round up sequence length - bucket_size = max_input_length - left_padding = max_input_length - input_len - if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: - assert ( - PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length - ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" - rounded_seq_len = round_up_seq( - input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE - ) - if rounded_seq_len <= max_input_length: - bucket_size = rounded_seq_len - 1 - else: - bucket_size = max_input_length - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - - old_bs = len(requests) - top_n_tokens.extend([-1] * (new_bs - old_bs)) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - htorch.core.mark_step() - return cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: - dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}") - request_ids = set(request_ids) - self.requests = [req for req in self.requests if req.data.id in request_ids] - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["CausalLMBatch"], pad_token_id: int = 0 - ) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id) - - def __len__(self): - return 
len(self.requests) - - @property - def max_input_length(self): - return max(req.input_length for req in self.requests) - - @property - def batch_size(self): - return self.attention_mask.size(0) - - @property - def seq_length(self): - return self.attention_mask.size(1) - - @property - def right_padding(self): - return self.seq_length - self.input_length - - # Maximum number of tokens this batch will grow to - @property - def max_tokens(self): - max_total_tokens = self.attention_mask.size(1) - return len(self.requests) * max_total_tokens - - -class CausalLM(Model): - def __init__( - self, - model_id: str, - model_class: Optional[Type[torch.nn.Module]] = None, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - default_dtype=torch.float16, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - config_class=AutoConfig, - batch_class=CausalLMBatch, - ): - if speculator: - raise RuntimeError("Speculator decoding is not enabled for AutoModel") - - self.prev_bs = 0 - self.quantize = quantize - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - # Get weight files - weight_files(model_id, revision=revision, extension=".safetensors") - - if world_size > 1: - os.environ.setdefault( - "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" - ) - model = self.get_deepspeed_model(model_id, dtype, revision) - model = hq_env.prepare_model_for_quantization(model) - else: - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained(model_id) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs, - ) - model = hq_env.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = ( - os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - ) - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true" - - if model.config.model_type not in [ - "gpt_bigcode" - ]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output() - model = remove_kv_cache_from_output(model) - - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace("TORCH COMPILE", "Torch compiling of model") - model.model = torch.compile( - model.model, - backend="hpu_backend", - options={"keep_input_mutations": True}, - ) - - model = hq_env.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - if 
isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" - ) - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - self.kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in [ - "llama", - "mistral", - "starcoder2", - "qwen2", - "falcon", - "gpt_bigcode", - ]: - if model.config.model_type not in ["falcon", "gpt_bigcode"]: - self.kwargs["attn_softmax_bf16"] = True - - if model.config.model_type not in ["gpt_bigcode"]: - self.kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true": - self.kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true": - self.kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = ( - int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_steps = ( - int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes, - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - def get_deepspeed_model( - self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = {"revision": revision} - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format( - world_size, rank, local_rank - ) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) - else: - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = AutoModelForCausalLM.from_pretrained( - model_id, torch_dtype=dtype, **model_kwargs - ) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [ - str(f) - for f in weight_files( - model_id, revision=revision, extension=".safetensors" - ) - ] - data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, checkpoints_json) - checkpoints_json.flush() - - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return {"type": rope_scaling, "factor": float(rope_factor)} - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode( - all_input_ids[read_offset:], skip_special_tokens=False - ) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) - - def forward( - self, - input_ids, - attention_mask, - position_ids, - token_idx, - past_key_values: Optional[List[Tuple]] = None, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "token_idx": token_idx, - } - - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama": - kwargs["lazy_mode"] = LAZY_MODE == 1 - - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - if bypass_hpu_graph is not None: - kwargs["bypass_hpu_graphs"] = bypass_hpu_graph - - kwargs.update(self.kwargs) - - if past_key_values is not None and self.model.config.model_type not in [ - "gpt_bigcode" - ]: - return self.model.forward(**kwargs) - else: - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values - - 
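# --- Editor's note (not part of the patch): a small, hedged illustration of how
# --- the ROPE_SCALING / ROPE_FACTOR environment variables are consumed by the
# --- get_rope_scaling() helper shown above; the values set here are illustrative.
import os

os.environ["ROPE_SCALING"] = "linear"
os.environ["ROPE_FACTOR"] = "2.0"

rope_scaling = os.getenv("ROPE_SCALING", None)
rope_config = (
    None
    if rope_scaling is None
    # Mirrors the helper: the scaling type comes from ROPE_SCALING, the factor
    # from ROPE_FACTOR (defaulting to 1.0 when unset).
    else {"type": rope_scaling, "factor": float(os.getenv("ROPE_FACTOR", 1.0))}
)
print(rope_config)  # {'type': 'linear', 'factor': 2.0}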
@tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: List[CausalLMBatch] - ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = ( - batch.attention_mask.shape[-1] - batch.right_padding - ) - token_idx = torch.tensor(token_idx_scalar).to(self.device) - - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, - logits[:, input_length - 1 : input_length, :].squeeze(-2), - self.speculate, - ) - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append( - { - "next_token_ids": next_token_ids, - "next_token_logprobs": next_token_logprobs, - } - ) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append( - { - "req": req, - "prev_req_idx": req.idx, - "batch_id": batch_id, - "seed": batch.next_token_chooser.seeds[req_idx], - "do_sample": batch.next_token_chooser.do_sample[req_idx], - "top_n_tokens": batch.top_n_tokens[req_idx], - "top_token_ids": batch_top_token_ids[req_idx], - "top_token_logprobs": batch_top_token_logprobs[req_idx], - "grammar_state": batch.next_token_chooser.fsm_grammar_states[ - req.idx - ], - } - ) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # Adjust lengths - batch.input_length += 1 - - # Update position_ids - if prefill: - batch.position_ids = ( - torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - ) - else: - batch.position_ids += 1 - # Update past key values - if prefill or self.model.config.model_type in ["gpt_bigcode"]: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. 
Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id) - - scenario = "PREFILL" if prefill else "GENERATE" - if ( - self.enable_hpu_graph - and self.limit_hpu_graph - and round_up_batch(batch.batch_size) != self.prev_bs - ): - self.model.clear_cache() - self.prev_bs = round_up_batch(batch.batch_size) - dbg_trace( - scenario, - f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", - ) - assert batch.right_padding > 0, "No more room for next token!" - - # Execute batch - if prefill: - # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch.input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - token_idx = torch.tensor( - batch.attention_mask.shape[-1] - batch.right_padding - ).to(self.device) - input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) - logits = self.forward( - input_ids, - batch.attention_mask, - batch.position_ids, - token_idx, - batch.past_key_values, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - if self.model.config.model_type in ["gpt_bigcode"]: - batch.logits, batch.past = logits - else: - batch.logits = logits - - htorch.core.mark_step() - - start_decode = time.time_ns() - - # Stage 3. 
Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch["next_token_logprobs"] = prev_batch[ - "next_token_logprobs" - ].tolist() - prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data["req"] - i = req_data["prev_req_idx"] - prev_batch_id = req_data["batch_id"] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] - next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data["do_sample"] - seed = req_data["seed"] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data["top_n_tokens"] - top_token_ids = req_data["top_token_ids"] - top_token_logprobs = req_data["top_token_logprobs"] - grammar_state = req_data["grammar_state"] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if ( - is_tokenizer_transparent(self.tokenizer) - and len(stopping_criteria.stop_sequence_criterias) == 0 - ): - next_token_text = "" - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[ - new_input_length - - stopping_criteria.current_tokens : new_input_length, - 0, - ] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - 
[next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if ( - self.step - > self.profiling_wait_steps - + self.profiling_warmup_steps - + self.profiling_steps - ): - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def generate_warmup_batch(self, request, seq_len, batch_size): - batch = copy.deepcopy(request.batch) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device) - - def warmup( - self, request: generate_pb2.WarmupRequest - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - assert ( - MAX_BATCH_SIZE is not None - ), "MAX_BATCH_SIZE is not set, it should be set in the launcher" - MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens - logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}") - logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}") - MAX_TOTAL_TOKENS = request.max_total_tokens - - batch = self.batch_type.from_pb( - request.batch, self.tokenizer, self.dtype, self.device - ) - max_prefill_batch_size = batch.input_ids.shape[0] - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch]) - except Exception: - raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - del prefill_batch - - # Warmup prefill batch_size - max_input_tokens = request.max_input_tokens - max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE)) - prefill_batch_size_list = [ - BATCH_SIZE_EXPONENT_BASE**exp - for exp in range( - 0, - max_exp + 1, - ) - ] - prefill_seqlen_list = iterate_powers_of_base( - max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE - ) - prefill_seqlen_list.append(max_input_tokens) - prefill_batch_size_list.sort(reverse=True) - prefill_seqlen_list.sort(reverse=True) - logger.info( - f"Prefill batch size list:{prefill_batch_size_list}\n" - f"Prefill sequence length list:{prefill_seqlen_list}\n" - ) - try: - for batch_size in prefill_batch_size_list: - for seq_len in prefill_seqlen_list: - logger.info( - f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..." - ) - batch = self.generate_warmup_batch(request, seq_len - 1, batch_size) - _, prefill_batch, _ = self.generate_token([batch]) - except Exception: - prefill_batch_size_list.sort() - prefill_seqlen_list.sort() - raise RuntimeError( - f"Not enough memory to run following prefill batch_size." 
- f"Prefill batch size list:{prefill_batch_size_list}" - f"Prefill sequence length list:{prefill_seqlen_list}" - f"You need to decrease `--max-batch-prefill-tokens`" - ) - prefill_seqlen_list.sort() - prefill_batch_size_list.sort() - mem_stats = get_hpu_memory_stats(self.device) - logger.info(f"Prefill warmup successful.\n" f"Memory stats: {mem_stats} ") - - max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) - max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE)) - decode_batch_size_list = [ - BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1) - ] - decode_batch_size_list.sort(reverse=True) - logger.info(f"Decode batch size list:{decode_batch_size_list}\n") - - try: - for batch_size in decode_batch_size_list: - logger.info( - f"Decode warmup for `batch_size={batch_size}`, this may take a while..." - ) - batches = [] - iters = math.floor(batch_size / max_prefill_batch_size) - for i in range(iters): - batch = self.generate_warmup_batch( - request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size - ) - _, prefill_batch, _ = self.generate_token([batch]) - batches.append(prefill_batch) - - if batch_size % max_prefill_batch_size != 0: - batch = self.generate_warmup_batch( - request, - PAD_SEQUENCE_TO_MULTIPLE_OF - 1, - batch_size % max_prefill_batch_size, - ) - _, prefill_batch, _ = self.generate_token([batch]) - batches.append(prefill_batch) - - _, decode_batch, _ = self.generate_token(batches) - _, decode_batch, _ = self.generate_token([decode_batch]) - del decode_batch - batches.clear() - - except Exception: - raise RuntimeError( - f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})." - f"You need to decrease `--max-batch-total-tokens`" - ) - - decode_batch_size_list.sort() - max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1] - mem_stats = get_hpu_memory_stats(self.device) - logger.info(f"Decode warmup successful.\n" f"Memory stats: {mem_stats} ") - - max_input_tokens = max_input_tokens - max_total_tokens = MAX_TOTAL_TOKENS - - return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 3bcc689d..7a32a85c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -51,6 +52,8 @@ from habana_frameworks.torch.hpex.kernels import ( apply_rotary_pos_emb, ) +import habana_frameworks.torch as htorch + class CohereRotary(PositionRotaryEmbedding): def forward( @@ -413,6 +416,10 @@ class FlashCohereModel(torch.nn.Module): seqlen: torch.Tensor, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -420,7 +427,9 @@ class FlashCohereModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None - + lazy_mode = htorch.utils.internal.is_lazy() + if 
lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -433,6 +442,8 @@ class FlashCohereModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 15c243c9..42af7798 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -44,6 +45,7 @@ from text_generation_server.layers.layernorm import ( FastLayerNorm, ) from vllm_hpu_extension.ops import DynamicFusedMOE +import habana_frameworks.torch as htorch class DbrxAttentionConfig(PretrainedConfig): @@ -677,13 +679,19 @@ class DbrxModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) - residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -696,6 +704,8 @@ class DbrxModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 9d61c694..8e9002a2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -40,6 +41,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.weights import Weights +import habana_frameworks.torch as htorch class DeepseekV2Config(PretrainedConfig): @@ -568,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -575,6 +581,9 @@ class DeepseekV2Model(torch.nn.Module): cos, 
sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -587,6 +596,8 @@ class DeepseekV2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 1a7ce5cf..8e058093 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -28,11 +28,13 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, get_linear, + Fp8Linear, ) from text_generation_server.layers.attention import ( Seqlen, attention, - paged_attention, + paged_attention_mla, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -40,6 +42,19 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.weights import Weights +import habana_frameworks.torch as htorch + + +def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor: + if isinstance(layer, Fp8Linear): + eye = torch.eye( + layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device + ) + dequant_weights = layer(eye) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight class DeepseekV3Config(PretrainedConfig): @@ -249,6 +264,44 @@ class DeepseekV3Attention(torch.nn.Module): 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.value_head_size, + ) + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _q_proj_and_k_up_proj(self, x): + q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj + q_nope, q_pe = ( + q_proj(x) + .view(-1, self.num_heads, self.head_size) + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def _v_up_proj_and_o_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size) + return self.o_proj(x) + def forward( self, 
hidden_states: torch.Tensor, @@ -261,14 +314,9 @@ class DeepseekV3Attention(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: - query = self.q_proj(hidden_states) + hidden_states_or_q_c = hidden_states else: - query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) - query = query.view(-1, self.num_heads, self.head_size) - - _, query_pe = torch.split( - query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, key_pe = torch.split( @@ -276,13 +324,18 @@ class DeepseekV3Attention(torch.nn.Module): ) key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( - -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size - ) + kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0] - key_nope, value = torch.split( - kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 - ) + # Prefill + if cu_seqlen_prefill is not None: + q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj + query = q_proj(hidden_states_or_q_c) + query = query.view(-1, self.num_heads, self.head_size) + query_nope, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + else: + query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c) batch_size, heads, head_dim = query_pe.shape query_pe = ( @@ -297,33 +350,47 @@ class DeepseekV3Attention(torch.nn.Module): .reshape(batch_size, heads, head_dim) ) self.rotary_emb(query_pe, key_pe, cos, sin) + latent_vec_k = torch.concat( + (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1 + ) + latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank) - query[..., self.qk_nope_head_dim :] = query_pe - key = torch.empty_like(query) - key[..., : self.qk_nope_head_dim] = key_nope - key[..., self.qk_nope_head_dim :] = key_pe - - # We need to pad the heads because Flash Attention does not support - # qk and v with different head sizes. - query = torch.nn.functional.pad( - query, (0, self.head_pad_size - self.head_size), value=0 - ) - key = torch.nn.functional.pad( - key, (0, self.head_pad_size - self.head_size), value=0 - ) - value = torch.nn.functional.pad( - value, (0, self.head_pad_size - self.value_head_size), value=0 - ) + latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1)) kv_cache.store( - key=key, - value=value, + key=latent_vec_k, + value=None, slots=slots, kv_scales=self.kv_scales, ) - # Prefill if cu_seqlen_prefill is not None: + kv = self.kv_b_proj(kv_c_normed).view( + -1, + self.num_key_value_heads, + self.qk_nope_head_dim + self.value_head_size, + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. 
+ query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + # flash attention attn_output = attention( query=query, @@ -334,9 +401,15 @@ class DeepseekV3Attention(torch.nn.Module): seqlen=seqlen, softmax_scale=self.softmax_scale, ) - # Decode + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) else: - attn_output = paged_attention( + # Decode + query = torch.cat([query_nope, query_pe], dim=-1) + attn_output = paged_attention_mla( query, kv_cache, self.kv_head_mapping, @@ -344,14 +417,10 @@ class DeepseekV3Attention(torch.nn.Module): seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, + kv_lora_rank=self.kv_lora_rank, ) - - # Remove padding. - attn_output = attn_output[..., : self.value_head_size] - - return self.o_proj( - attn_output.reshape(-1, self.num_heads * self.value_head_size) - ) + attn_output = self._v_up_proj_and_o_proj(attn_output) + return attn_output class DeepseekV3MLP(nn.Module): @@ -577,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -584,6 +657,9 @@ class DeepseekV3Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -596,6 +672,8 @@ class DeepseekV3Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 79f21b0f..a1a20999 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -46,6 +47,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class Gemma2Config(PretrainedConfig): @@ -465,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module): adapter_data: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward @@ -472,6 +478,10 @@ class FlashGemma2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = 
htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -485,6 +495,8 @@ class FlashGemma2Model(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 609f03ac..7a2ec22e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -44,6 +45,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class GemmaConfig(PretrainedConfig): @@ -387,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module): adapter_data: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward @@ -394,6 +400,9 @@ class FlashGemmaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -406,6 +415,8 @@ class FlashGemmaModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 10024a6d..a6b53656 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -27,6 +27,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -38,6 +39,7 @@ from text_generation_server.layers import ( get_linear, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales +import habana_frameworks.torch as htorch def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -382,9 +384,17 @@ class FlashGPT2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -395,6 +405,8 @@ class FlashGPT2Model(torch.nn.Module): seqlen, 
hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 41eeab78..679380a1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -48,6 +49,7 @@ from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb, ) +import habana_frameworks.torch as htorch def load_attention(config, prefix: str, weights): @@ -323,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward @@ -330,6 +336,9 @@ class FlashGPTJModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -342,6 +351,8 @@ class FlashGPTJModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py new file mode 100644 index 00000000..3b30f2e0 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -0,0 +1,1439 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
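# --- Editor's note (not part of the patch): the hunks above add the same
# --- per-layer graph-break pattern to each flash model's forward pass. Below is
# --- a hedged, stand-alone sketch of that pattern; it assumes a Gaudi host
# --- where habana_frameworks is importable, and the layer call signature is
# --- simplified for illustration.
import habana_frameworks.torch as htorch

def run_decoder_layers(layers, hidden_states, *layer_args):
    # In lazy mode, calling mark_step() before the loop and after every layer
    # breaks the accumulated HPU lazy graph at layer boundaries (the usual
    # motivation is smaller graphs and better host/device overlap); in eager
    # mode the calls are skipped entirely.
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()
    residual = None
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual, *layer_args)
        if lazy_mode:
            htorch.core.mark_step()
    return hidden_states, residual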
+ +from typing import List, Optional, Tuple, Union + +import torch +import math +import torch.utils.checkpoint +from torch import nn +import torch.nn.functional as F + +import habana_frameworks.torch as htorch +from transformers.cache_utils import Cache +from transformers.activations import ACT2FN +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_outputs import ( + BaseModelOutput, +) + +from transformers.modeling_attn_mask_utils import AttentionMaskConverter + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + SpeculativeHead, + FastLinear, +) +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.attention import ( + KVCache, + paged_attention, + set_block_mapping, + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaAttention, +) + + +def reshape_for_broadcast(freqs: torch.Tensor, target): + ndim = len(target) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)] + return freqs.view(*shape) + + +def apply_rotary_emb( + query: torch.Tensor, + key: torch.Tensor, + freqs_ci: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + query_shape = query.shape + key_shape = key.shape + cos_emb, sin_emb = freqs_ci.split(1, dim=-1) + + if len(query.shape) == 3: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + + query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2) + key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2) + q_shape = query_reshaped.shape[:-1] + cos_emb = reshape_for_broadcast(cos_emb, q_shape) + sin_emb = reshape_for_broadcast(sin_emb, q_shape) + x_q, y_q = query_reshaped.unbind(-1) + x_k, y_k = key_reshaped.unbind(-1) + + x_q_rot = x_q * cos_emb - y_q * sin_emb + y_q_rot = x_q * sin_emb + y_q * cos_emb + x_k_rot = x_k * cos_emb - y_k * sin_emb + y_k_rot = x_k * sin_emb + y_k * cos_emb + + query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2) + key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2) + query_out = query_out.view(*query_shape) + key_out = key_out.view(*key_shape) + return query_out.type_as(query), key_out.type_as(key) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Llama4TextExperts(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.process_group = weights.process_group + self.num_experts = config.num_local_experts + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + self.gate_up_proj = nn.Parameter( + weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), + requires_grad=False, + ) + self.down_proj = nn.Parameter( + weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This should really not be run on a single machine, as we are reaching compute bound: + - the inputs are expected to be "sorted" per expert already. + - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim) + down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), down_proj) + next_states = next_states.view(-1, self.hidden_size) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(next_states, group=self.process_group) + + return next_states + + +# Phi3MLP +class Llama4TextMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config=config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config=config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class Llama4TextL2Norm(torch.nn.Module): + def __init__(self, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + +class Llama4TextMoe(nn.Module): + def __init__( + self, + prefix, + config, + 
weights,
+    ):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.hidden_dim = config.hidden_size
+        self.num_experts = config.num_local_experts
+        self.experts = Llama4TextExperts(
+            config=config, prefix=f"{prefix}.experts", weights=weights
+        )
+        self.router = FastLinear.load(
+            config=config, prefix=f"{prefix}.router", weights=weights, bias=False
+        )
+        self.shared_expert = Llama4TextMLP(
+            config=config, prefix=f"{prefix}.shared_expert", weights=weights
+        )
+        self.process_group = weights.process_group
+
+    def forward(self, hidden_states, adapter_data):
+        seq_len, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_dim)
+        tokens_per_expert = hidden_states.shape[0]
+        router_logits = self.router(hidden_states)
+
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+        router_scores = (
+            torch.full_like(router_logits, float("-inf"))
+            .scatter_(1, router_indices, router_top_value)
+            .transpose(0, 1)
+        )
+        # We do this to make sure we have -inf for the non top-k tokens before going through the sigmoid.
+        # Here we are just creating a tensor to index each and every single one of the hidden states. Let's maybe register a buffer for this!
+        router_indices = (
+            torch.arange(tokens_per_expert, device=hidden_states.device)
+            .view(1, -1)
+            .expand(router_scores.size(0), -1)
+        )
+        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+        router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
+        routed_in = torch.gather(
+            input=hidden_states,
+            dim=0,
+            index=router_indices,
+        ).to(hidden_states.device)
+
+        # we gather inputs corresponding to each expert based on the router indices
+        routed_in = routed_in * router_scores.reshape(-1, 1)
+        routed_out = self.experts(routed_in)
+        out = self.shared_expert(hidden_states)
+
+        # now that expert computation is finished -> we scatter-add, because we gathered previously
+        # we have to do this because we used all experts on all tokens. This is faster than a for loop, though you are compute bound
+        # this scales a lot better if you use expert parallelism (EP)!
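+        # `out` is the shared-expert output of shape (num_tokens, hidden_dim); `routed_out` is
+        # (num_experts * num_tokens, hidden_dim) and `router_indices` maps every routed row back to
+        # its source token, so scatter_add_ accumulates each expert's contribution onto that token.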
+ out.scatter_add_( + dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim) + ) + return out + + +class Llama4TextRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + self.rope_type = "llama3" if config.rope_scaling is not None else "default" + + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + inv_freq_expanded = inv_freq_expanded.to(device_type) + position_ids_expanded = position_ids_expanded.to(device_type) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + freqs_cis = ( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + * self.attention_scaling + ) + return freqs_cis.to(dtype=x.dtype, device=x.device) + + +class Llama4TextAttention(FlashLlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, prefix, config, weights, layer_idx): + super().__init__(layer_idx, prefix, config, weights) + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers + + if self.config.use_qk_norm and self.use_rope: + self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_ci, + cu_seqlen_prefill, + kv_cache: KVCache, + slots, + seqlen, + adapter_data, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bs = seqlen.input_lengths.shape[0] + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + qkv = self.query_key_value(hidden_states, adapter_data) + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_key_value_heads, + self.head_dim * self.num_key_value_heads, + ], + dim=-1, + ) + + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, freqs_ci + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + 
query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + kv_cache.store( + key=key_states, + value=value_states, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log( + torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0 + ) + * self.attn_scale + + 1.0 + ) + attn_scales = attn_scales.view(*input_shape, 1, 1) + query_states = (query_states * attn_scales).to(query_states.dtype) + + # Prefill + if cu_seqlen_prefill is not None: + # sdpa + query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose( + 1, 2 + ) + key = key_states.view( + bs, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value = value_states.view( + bs, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + key = repeat_kv(key, self.num_key_value_groups) + value = repeat_kv(value, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + is_causal = query.shape[2] > 1 and causal_mask is None + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=0, + scale=self.scaling, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + # Decode + else: + attn_output = paged_attention( + query_states, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output, adapter_data) + return attn_output + + +class Llama4TextDecoderLayer(nn.Module): + def __init__(self, prefix, config, weights, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Llama4TextAttention( + f"{prefix}.self_attn", config, weights, layer_idx + ) + self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope + self.is_moe_layer = layer_idx in config.moe_layers + if self.is_moe_layer: # the 128E model interleaves dense / sparse + self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights) + else: + self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + freqs_ci, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + attention_mask: Optional[torch.Tensor] = None, + chunk_causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states, _ = self.input_layernorm(hidden_states) + + # use local attention mask 
for ROPE layers + if self.use_chunked_attention and chunk_causal_mask is not None: + attention_mask = chunk_causal_mask + + attention_states = self.self_attn( + hidden_states, + freqs_ci, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data, + attention_mask=attention_mask, + position_ids=position_ids, + hpu_attention_meta=hpu_attention_meta, + ) + + hidden_states = residual + attention_states + + # Fully Connected + residual = hidden_states + + hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, adapter_data) + hidden_states = residual + hidden_states.view(residual.shape) + return hidden_states + + +class Llama4TextModel(nn.Module): + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + Llama4TextDecoderLayer( + prefix=f"{prefix}.layers.{layer_idx}", + config=config, + weights=weights, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + self.rotary_emb = Llama4TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) + + hidden_states = inputs_embeds + bs = seqlen.input_lengths.shape[0] + seq_len = inputs_embeds.shape[0] / bs + cache_position = torch.arange(0, seq_len, device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask, chunk_causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds.view(bs, int(seq_len), -1), + cache_position, + None, + output_attentions=False, + use_cache=False, + ) + + freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + + for i, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, + freqs_ci, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + adapter_data, + attention_mask=causal_mask, + chunk_causal_mask=chunk_causal_mask, + position_ids=position_ids, + hpu_attention_meta=hpu_attention_meta, + ) + if lazy_mode: + htorch.core.mark_step() + + hidden_states, _ = self.norm(hidden_states) + + return hidden_states + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + chunked_attention_mask=None, + use_cache=True, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return ( + attention_mask, + attention_mask, + ) # flash does not support chunked attn TODO support flash + 
return None, None + + if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]: + return None, None + + sequence_length = input_tensor.shape[1] + attention_chunk_size = self.config.attention_chunk_size + + first_cache_position = cache_position[0] + + if past_key_values is not None: + full_cache_length = past_key_values.get_max_cache_shape() or sequence_length + else: + full_cache_length = ( + attention_mask.shape[-1] + if attention_mask is not None + else sequence_length + ) + + cond1 = first_cache_position >= attention_chunk_size + cond2 = (first_cache_position < attention_chunk_size) & ( + first_cache_position + sequence_length > attention_chunk_size + ) + key_length = ( + torch.where( + cond1, + attention_chunk_size + sequence_length - 1, + torch.where( + cond2, first_cache_position + sequence_length, attention_chunk_size + ), + ) + if use_cache + else full_cache_length + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + dtype, device = input_tensor.dtype, input_tensor.device + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=max(full_cache_length, attention_chunk_size), + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + device=device, + ) + if full_cache_length > self.config.attention_chunk_size: + start_idx = max(first_cache_position - attention_chunk_size + 1, 0) + end_idx = start_idx + key_length + chunked_attention_mask = self.create_chunked_attention_mask( + self.config.attention_chunk_size, + start=start_idx, # same offset as with flex + end=end_idx, + device=device, + ) + + local_attention_mask = attention_mask[ + :, start_idx:end_idx + ] # offset here as well + # It may be smaller than attention_chunk_size -> pad it + requires_padding = local_attention_mask.shape[-1] < attention_chunk_size + if requires_padding: + local_attention_mask = nn.functional.pad( + local_attention_mask, + (0, attention_chunk_size - local_attention_mask.shape[-1]), + ) + # Depending on the padding, take the query tokens from the end or the cache_position + if not requires_padding: + chunked_attention_mask = chunked_attention_mask[ + None, None, -sequence_length:, : + ] + else: + chunked_attention_mask = chunked_attention_mask[ + None, None, cache_position, : + ] + + chunked_attention_mask = chunked_attention_mask.expand( + input_tensor.shape[0], -1, -1, -1 + ) + chunked_attention_mask = ( + chunked_attention_mask * local_attention_mask[:, None, None, :] + ) + if self.config._attn_implementation == "eager": + min_dtype = torch.finfo(dtype).min + chunked_attention_mask = torch.where( + chunked_attention_mask == 0, min_dtype, 0.0 + ).to(dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and attention_mask.ndim == 4 + and not output_attentions # Only unmask for 4d masks + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and chunked_attention_mask is not None + ): + chunked_attention_mask = chunked_attention_mask.bool() + causal_mask = causal_mask.bool() + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=first_cache_position, + is_training=self.training, + ): + causal_mask = None + return causal_mask, chunked_attention_mask + + def create_chunked_attention_mask( + self, attention_chunk_size: int, start: int, end: int, device: torch.device + ) -> torch.Tensor: + """ + Generate the following: + + 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | + '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | + '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | + 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | + '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | + '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | + + If the chunk size is 3. + This can just be applied over the already created attention mask + """ + arange_vector = torch.arange(start, end, device=device) + block_pos = torch.abs( + arange_vector.unsqueeze(0) // attention_chunk_size + - arange_vector.unsqueeze(1) // attention_chunk_size + ) + token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) + mask = (block_pos == 0) & (token_pos <= 0) + return mask.to(device) + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.to(device).reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +class Llama4ForCausalLM(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.model = Llama4TextModel( + prefix=f"{prefix}.model", config=config, weights=weights + ) + self.vocab_size = config.vocab_size + self.lm_head = SpeculativeHead.load( + config, + f"{prefix}.lm_head", + weights, + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + adapter_data: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + adapter_data=adapter_data, + hpu_attention_meta=hpu_attention_meta, + attention_mask=attention_mask, + ) + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits + + +class Llama4VisionMLP2(torch.nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.fc1 = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.fc1", weights=weights, bias=False + ) + self.fc2 = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.fc2", weights=weights, bias=False + ) + self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] + self.dropout = config.projector_dropout + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + return self.activation_fn( + hidden_states + ) # TODO: check if we need to apply activation again + + +class Llama4MultiModalProjector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.linear_1 = FastLinear.load( + config=config, prefix=f"{prefix}.linear_1", weights=weights, bias=False + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + return hidden_states + + +def pixel_shuffle(input_tensor, shuffle_ratio): + # input_tensor: [batch_size, num_patches, channels] 
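+    # With shuffle ratio r this turns (batch, num_patches, channels) into
+    # (batch, num_patches * r**2, channels / r**2), i.e. fewer but wider patch
+    # tokens (e.g. r = 0.5 folds each 2x2 patch block into a single token).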
+ batch_size, num_patches, channels = input_tensor.shape + patch_size = int(math.sqrt(num_patches)) + + input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) + batch_size, height, width, channels = input_tensor.size() + reshaped_tensor = input_tensor.view( + batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) + ) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + reshaped_tensor = reshaped_tensor.view( + batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2)), + ) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) + return output_tensor + + +class Llama4VisionPixelShuffleMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.pixel_shuffle_ratio = config.pixel_shuffle_ratio + self.inner_dim = int( + config.projector_input_dim // (self.pixel_shuffle_ratio**2) + ) + self.output_dim = config.projector_output_dim + self.mlp = Llama4VisionMLP2( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: + encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) + return self.mlp(encoded_patches) + + +# TODO there is a different RoPE for vision encoder, defined as below +def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): + ndim = query.ndim + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)] + return freqs_ci.view(*shape) + + +class Llama4VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads // weights.process_group.size() + self.progress_group = weights.process_group + + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_groups = 1 + self.attention_dropout = config.attention_dropout + self.qkv_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_ci: torch.Tensor, # Now takes (cos_theta, sin_theta) instead of complex + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + ], + dim=2, + ) + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + + query_states, key_states = apply_rotary_emb( + query_states, key_states, freqs_ci=freqs_ci + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=False, + dropout_p=0, + ) + + attn_output = 
attn_output.transpose(1, 2) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Llama4VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Llama4VisionEncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Llama4VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = Llama4VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05 + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05 + ) + + def forward( + self, + hidden_state: torch.Tensor, + freqs_ci: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + + hidden_state = self.input_layernorm(hidden_state) + + hidden_state = self.self_attn( + hidden_state, + freqs_ci=freqs_ci, + attention_mask=attention_mask, + ) + + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = residual + hidden_state + outputs = (hidden_state,) + return outputs + + +class Llama4VisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Llama4VisionEncoderLayer`]. 
+ + Args: + config: Llama4VisionConfig + """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Llama4VisionEncoderLayer( + prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + freqs_ci=freqs_ci, + ) + + hidden_states = layer_outputs[0] + + return hidden_states + + +class Llama4UnfoldConvolution(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + kernel_size = config.patch_size + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) + self.linear = FastLinear.load( + config=config, prefix=f"{prefix}.linear", weights=weights, bias=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.unfold(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.linear(hidden_states) + return hidden_states + + +class Llama4VisionRotaryEmbedding(nn.Module): + def __init__(self, config, weights): + super().__init__() + # Calculate image grid indices + idx = config.image_size // config.patch_size + img_idx = torch.arange( + idx**2, dtype=torch.int32, device=weights.device + ).reshape(idx**2, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + + img_idx[-1, -1] = -2 # ID_CLS_TOKEN + # Calculate x and y coordinates + frequencies_x = img_idx % idx # x coordinates + frequencies_y = torch.div(img_idx, idx, rounding_mode="floor") # y coordinates + # Calculate frequency components + freq_dim = config.hidden_size // config.num_attention_heads // 2 + rope_freq = 1.0 / ( + config.rope_theta + ** ( + torch.arange(0, freq_dim, 2, device=weights.device)[ + : (freq_dim // 2) + ].float() + / freq_dim + ) + ) + + # Compute frequencies for x and y directions + freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :] + freqs_x = freqs_x.repeat_interleave(2, dim=-1) + freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :] + freqs_y = freqs_y.repeat_interleave(2, dim=-1) + + # Combine frequencies and mask special tokens + freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + + freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 + + def forward(self, hidden_states): + """ + Returns the rotary embedding components (cosθ, sinθ) for the given hidden states + """ + return self.freqs_ci.to(dtype=hidden_states.dtype, device=hidden_states.device) + + +class Llama4VisionModel(nn.Module): + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = 
Llama4UnfoldConvolution( + prefix=f"{prefix}.patch_embedding", config=config, weights=weights + ) + + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False + ) + + self.positional_embedding_vlm = nn.Parameter( + weights.get_tensor(f"{prefix}.positional_embedding_vlm"), + requires_grad=False, + ) + + self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights) + + # layer norms + self.layernorm_pre = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_pre", weights=weights, eps=config.norm_eps + ) + self.layernorm_post = nn.LayerNorm.load( + prefix=f"{prefix}.layernorm_post", weights=weights, eps=config.norm_eps + ) + + # encoders + self.model = Llama4VisionEncoder( + prefix=f"{prefix}.model", config=config, weights=weights + ) + self.vision_adapter = Llama4VisionPixelShuffleMLP( + prefix=f"{prefix}.vision_adapter", config=config, weights=weights + ) + + def forward( + self, + pixel_values: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ): + # num_concurrent_media and num_chunks are both currently 1 + batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape + num_concurrent_media = 1 + num_chunks = 1 + hidden_state = self.patch_embedding(pixel_values) + _, num_patches, hidden_dim = hidden_state.shape + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size_times_num_tiles * num_concurrent_media * num_chunks, + num_patches, + hidden_dim, + ) + class_embedding = self.class_embedding.expand( + hidden_state.shape[0], 1, hidden_state.shape[-1] + ) + hidden_state = torch.cat([hidden_state, class_embedding], dim=1) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + batch_size_times_num_tiles * num_concurrent_media, + num_chunks, + num_patches, + hidden_dim, + ) + positional_embedding = self.positional_embedding_vlm.to( + dtype=hidden_state.dtype, device=hidden_state.device + ) + hidden_state = hidden_state + positional_embedding + hidden_state = self.layernorm_pre(hidden_state) + hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim) + freqs_ci = self.rotary_embedding(pixel_values) + + hidden_state = self.model( + hidden_state, + attention_mask=None, + freqs_ci=freqs_ci, + ) + + hidden_state = self.layernorm_post(hidden_state) + + hidden_state = hidden_state[:, :-1, :] + + # now, we use Llama4VisionPixelShuffle + mlp to project embeddings + hidden_state = self.vision_adapter(hidden_state) + return hidden_state + + +class Llama4ForConditionalGeneration(nn.Module): + + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + config.text_config.quantize = config.quantize + config.text_config.speculator = config.speculator + config.text_config._attn_implementation = None + + self.vision_model = Llama4VisionModel( + prefix="vision_model", config=config.vision_config, weights=weights + ) + + self.multi_modal_projector = Llama4MultiModalProjector( + prefix="multi_modal_projector", config=config, weights=weights + ) + + self.text_model = Llama4ForCausalLM( + prefix="language_model", config=config.text_config, weights=weights + ) + self.vocab_size = config.text_config.vocab_size + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + self.config = config + self.dtype = weights.dtype + self.device = weights.device + + def 
get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        vision_feature_layer: Union[int, List[int]],
+        vision_feature_select_strategy: str,
+        **kwargs,
+    ):
+        """
+        Obtains image last hidden states from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+            vision_feature_layer (`Union[int, List[int]]`):
+                The index of the layer to select the vision feature. If multiple indices are provided,
+                the vision feature of the corresponding indices will be concatenated to form the
+                vision features.
+            vision_feature_select_strategy (`str`):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Can be one of `"default"` or `"full"`
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        """
+        if vision_feature_select_strategy not in ["default", "full"]:
+            raise ValueError(
+                f"Unexpected select feature strategy: {vision_feature_select_strategy}"
+            )
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        hidden_state = self.vision_model(pixel_values)
+        return hidden_state
+
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.model.embed_tokens(input_ids)
+
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            original_inputs_embeds_shape = inputs_embeds.shape
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
+                -1
+            )
+            final_mask = special_image_mask.to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+
+            final_mask_1d = final_mask[..., 0].reshape(-1)
+            num_tokens_to_fill = final_mask_1d.sum()
+
+            if num_tokens_to_fill != vision_embeds.size(0):
+                raise ValueError(
+                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
+                )
+
+            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
+                -1, inputs_embeds.size(-1)
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
+            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: 
Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + **lm_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + logits, speculative_logits = self.text_model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + adapter_data, + lm_head_indices, + attention_mask, + ) + + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 81af5560..70fcc824 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -26,7 +26,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN - +import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( KVCache, get_kv_scales, @@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -143,12 +144,14 @@ class FlashLlamaAttention(torch.nn.Module): config.num_key_value_heads = getattr( config, "num_key_value_heads", config.num_attention_heads ) - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) + + if config.model_type != "llama4_text": + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) # `config.attention_multiplier` is used in Granite self.softmax_scale = getattr( @@ -547,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) + hidden_states = inputs_embeds # Get rotary cos and sin for this forward @@ -554,6 +562,9 @@ class FlashLlamaModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -568,6 +579,8 @@ class FlashLlamaModel(torch.nn.Module): cross_attention_states, hpu_attention_meta=hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py index 88548042..c4d4f728 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -163,9 +163,114 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): ) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + # 
num_special_image_tokens = (input_ids == self.config.image_token_index).sum() + # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" + # 1. Extract the input embeddings + + # 2. Merge text and images + num_images, num_patches, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view( + num_images * num_patches, channels, height, width + ) + image_features = self.vision_tower(pixel_values) + + # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] + # Already done within the clip model + selected_image_feature = image_features.last_hidden_state + + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise RuntimeError( + f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." + ) + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [num_patches] * num_images + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError( + "The number of patches is not consistent with the image size." 
+ ) + + # Dimensions are intentionally swapped to be bug-compatible with + # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1 + ), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, self.image_newline[None]), dim=0 + ) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + return image_features.view(-1, image_features.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -173,101 +278,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - # Unused for this model - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None and len(pixel_values) > 0: - # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() - # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" - # 1. Extract the input embeddings - - # 2. 
Merge text and images - num_images, num_patches, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view( - num_images * num_patches, channels, height, width - ) - image_features = self.vision_tower(pixel_values) - - # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] - # Already done within the clip model - selected_image_feature = image_features.last_hidden_state - - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise RuntimeError( - f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." - ) - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [num_patches] * num_images - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_features - ) hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d23d4f67..008df32d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from 
text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -45,6 +46,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +import habana_frameworks.torch as htorch class MistralConfig(PretrainedConfig): @@ -109,7 +111,8 @@ class MistralAttention(torch.nn.Module): ) self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size - if hasattr(config, "head_dim"): + + if getattr(config, "head_dim", None) is not None: self.head_size = config.head_dim else: self.head_size = self.hidden_size // self.num_heads @@ -395,12 +398,19 @@ class MistralModel(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -414,6 +424,8 @@ class MistralModel(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 1ef6be48..4993b444 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales @@ -44,6 +45,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class MixtralConfig(PretrainedConfig): @@ -445,6 +447,10 @@ class MixtralModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -452,6 +458,9 @@ class MixtralModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -464,6 +473,8 @@ class MixtralModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) @@ -499,7 +510,6 @@ class 
FlashMixtralForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model( input_ids, position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index 421a0a65..fe6d137b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import ( ) from habana_frameworks.torch.hpex.kernels import FusedSDPA from vllm_hpu_extension.utils import ModuleFusedSDPA +import habana_frameworks.torch as htorch def _prepare_aspect_ratio_attention_mask( @@ -236,10 +237,19 @@ class MllamaVisionSdpaAttention(nn.Module): key = key.transpose(1, 2) value = value.transpose(1, 2) - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask + fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attn_output = fsdpa_op( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + scale=None, + softmax_mode="None", + recompute_mode=None, + valid_sequence_lengths=None, ) - attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_seq_len, -1) @@ -320,6 +330,9 @@ class MllamaVisionEncoder(nn.Module): attention_mask: Optional[torch.Tensor] = None, ): encoder_states = [hidden_states] + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, @@ -328,6 +341,8 @@ class MllamaVisionEncoder(nn.Module): hidden_states = layer_outputs encoder_states.append(hidden_states) + if lazy_mode: + htorch.core.mark_step() return hidden_states, encoder_states @@ -699,8 +714,6 @@ class MllamaTextCrossAttention(nn.Module): # key_states = key_states.repeat(1, self.num_key_value_groups, 1) # value_states = value_states.repeat(1, self.num_key_value_groups, 1) - - causal = False # logger.info( # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}" # ) @@ -715,7 +728,7 @@ class MllamaTextCrossAttention(nn.Module): value_states, attn_mask=None, dropout_p=0.0, - is_causal=causal, + is_causal=False, scale=None, softmax_mode="None", recompute_mode=None, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 33f63333..6e1050b6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,6 +29,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -47,6 +48,7 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class GPTNeoXConfig(TransformersGPTNeoXConfig): @@ -353,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + 
if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward @@ -360,6 +366,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -372,6 +381,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.final_layer_norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 4d31d5dd..a13b9f09 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -62,10 +62,40 @@ class PaliGemmaForConditionalGeneration(nn.Module): self.pad_token_id = ( config.pad_token_id if config.pad_token_id is not None else -1 ) + self.dtype = weights.dtype + + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + pixel_values = pixel_values.to(dtype=self.dtype) + image_outputs = self.vision_tower(pixel_values) + last_hidden_state = self.post_vision_tower_layernorm( + image_outputs.last_hidden_state + ) + image_features = self.multi_modal_projector(last_hidden_state) + image_features = image_features.view(-1, image_features.shape[-1]) + return image_features + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + mask = input_ids == self.config.image_token_index + inputs_embeds[mask] = vision_embeds + + return inputs_embeds def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -73,32 +103,13 @@ class PaliGemmaForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - # Unused here - pixel_attention_mask: Optional[torch.BoolTensor] = None, - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. 
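The same Gaudi-specific pattern recurs across the model files touched in these hunks (Mistral, Mixtral, GPT-NeoX, and the others below): rebuild the HPU block mapping from the current batch size before the decoder loop, and insert `htorch.core.mark_step()` graph breaks around each layer when running in lazy mode. A minimal sketch of that shape, assuming a Gaudi environment with `habana_frameworks` installed and Mistral-style layers that return `(hidden_states, residual)`:

```python
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import set_block_mapping


def run_decoder_layers(layers, norm, inputs_embeds, hpu_attention_meta, **layer_kwargs):
    # Rebuild the block mapping for the batch size of this forward pass.
    if hpu_attention_meta is not None:
        hpu_attention_meta = set_block_mapping(hpu_attention_meta, inputs_embeds.shape[0])

    hidden_states = inputs_embeds
    residual = None

    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()  # flush the accumulated lazy graph before the layer loop
    for layer in layers:
        hidden_states, residual = layer(
            hidden_states, residual, hpu_attention_meta=hpu_attention_meta, **layer_kwargs
        )
        if lazy_mode:
            htorch.core.mark_step()  # one graph break per decoder layer
    hidden_states, _ = norm(hidden_states, residual)
    return hidden_states
```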
if cu_seqlen_prefill is not None: position_ids += 1 - if pixel_values is not None: - pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) - image_outputs = self.vision_tower(pixel_values) - last_hidden_state = self.post_vision_tower_layernorm( - image_outputs.last_hidden_state - ) - image_features = self.multi_modal_projector(last_hidden_state) - - # mask where image or padding tokens - mask = input_ids == self.config.image_token_index - - # insert image features into input embeddings - inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 0c777912..78aaf0d5 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,6 +9,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -26,6 +27,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +import habana_frameworks.torch as htorch class PhiConfig(PretrainedConfig): @@ -346,6 +348,10 @@ class FlashPhiModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -353,6 +359,9 @@ class FlashPhiModel(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -365,6 +374,8 @@ class FlashPhiModel(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py index bb585cc4..c28f3aee 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py @@ -18,7 +18,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index af4b404d..ac31e53b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,6 +8,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + 
set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -22,6 +23,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +import habana_frameworks.torch as htorch def load_attention(config, prefix, weights): @@ -287,6 +289,10 @@ class Qwen2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -294,6 +300,9 @@ class Qwen2Model(torch.nn.Module): ) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states = layer( hidden_states, @@ -306,6 +315,8 @@ class Qwen2Model(torch.nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states) @@ -353,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py new file mode 100644 index 00000000..8bd00c13 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -0,0 +1,359 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
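The PaliGemma hunk above splits the old `forward(pixel_values=...)` path into three pieces: `get_vision_embeds` encodes the images, `get_inputs_embeds` embeds the token ids and writes the vision embeddings over the image-token positions, and `forward` now consumes `inputs_embeds` directly (the same split is applied to Idefics2 and Idefics3 further down). The batch code that drives these methods is not part of this excerpt, so the sequence below is only an assumed usage sketch built from the new signatures:

```python
def multimodal_forward(model, input_ids, position_ids, cu_seqlen_prefill,
                       kv_cache, slots, seqlen, hpu_attention_meta,
                       pixel_values=None):
    """Hypothetical caller-side sequence for the refactored interface."""
    vision_embeds = None
    if pixel_values is not None:
        # Encode the images once, outside the text forward pass.
        vision_embeds = model.get_vision_embeds(pixel_values=pixel_values)
    # Embed the token ids and scatter vision embeddings onto the image tokens.
    inputs_embeds = model.get_inputs_embeds(input_ids, vision_embeds=vision_embeds)
    # The language model now runs purely on embeddings.
    return model(
        inputs_embeds=inputs_embeds,
        position_ids=position_ids,
        cu_seqlen_prefill=cu_seqlen_prefill,
        kv_cache=kv_cache,
        slots=slots,
        seqlen=seqlen,
        hpu_attention_meta=hpu_attention_meta,
    )
```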
+ +from typing import Optional, Tuple, List + +import torch +from torch import nn +import habana_frameworks.torch as htorch +from text_generation_server.layers.attention import ( + paged_attention, + attention, + set_block_mapping, + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.layers.attention.kv_cache import get_kv_scales +from text_generation_server.layers import ( + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, + SpeculativeHead, +) + + +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from .flash_qwen2_modeling import Qwen2MLP +from text_generation_server.layers.rotary import PositionRotaryEmbedding + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, prefix, weights, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_heads = config.num_attention_heads + self.attention_dropout = config.attention_dropout + self.softmax_scale = self.head_dim**-0.5 + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_dim, + base=config.rope_theta, + device=weights.device, + ) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + + self.q_norm = FastRMSNorm.load( + prefix=f"{prefix}.q_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.k_norm = FastRMSNorm.load( + prefix=f"{prefix}.k_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + qkv = self.query_key_value(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_key_value_heads, + self.head_dim * self.num_key_value_heads, + ], + dim=1, + ) + + query_states, _ = 
self.q_norm(query_states.view(hidden_shape)) + key_states, _ = self.k_norm(key_states.view(hidden_shape)) + value_states = value_states.view(hidden_shape) + self.rotary_emb(query_states, key_states, cos, sin) + + kv_cache.store( + key=key_states, + value=value_states, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # sdpa + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query_states, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output) + + +class Qwen3DecoderLayer(nn.Module): + def __init__(self, config, prefix, weights, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3Attention( + config=config, + prefix=f"{prefix}.self_attn", + weights=weights, + layer_idx=layer_idx, + ) + self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, _ = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3Model(nn.Module): + def __init__(self, config, prefix: str, weights): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer( + config=config, + prefix=f"{prefix}.layers.{layer_idx}", + weights=weights, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, + ) + + residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + + for i, decoder_layer in 
enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + hpu_attention_meta, + ) + if lazy_mode: + htorch.core.mark_step() + + hidden_states, _ = self.norm(hidden_states) + + # add hidden states from the last decoder layer + return hidden_states + + +class Qwen3ForCausalLM(nn.Module): + + def __init__(self, prefix: str, config, weights): + super().__init__() + self.model = Qwen3Model(config=config, prefix="model", weights=weights) + self.vocab_size = config.vocab_size + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.{suffix}" if prefix else suffix, + weights=weights, + ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py new file mode 100644 index 00000000..5e4bc7fa --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py @@ -0,0 +1,542 @@ +# coding=utf-8 +# Copyright 5 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
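The new `Qwen3Attention` above differs from the Qwen2 layer mainly in applying `FastRMSNorm` per head to the query and key states (`q_norm`/`k_norm`) before the rotary embedding is applied. A plain-PyTorch reference of just that step, with illustrative shapes and an assumed `eps`:

```python
import torch


def qk_norm(q, k, q_weight, k_weight, eps=1e-6):
    # q, k: [num_tokens, num_heads, head_dim]; RMSNorm over the last dim of each head.
    def rms_norm(x, weight):
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    return rms_norm(q, q_weight), rms_norm(k, k_weight)


num_tokens, num_heads, head_dim = 4, 8, 64
q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
q_w = torch.ones(head_dim)  # learned q_norm weight in the real model
k_w = torch.ones(head_dim)  # learned k_norm weight in the real model
q, k = qk_norm(q, k, q_w, k_w)  # then rotary_emb(q, k, cos, sin) as in the diff
```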
+ +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +import torch.nn.functional as F +from text_generation_server.layers.attention import ( + attention, + paged_attention, + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.layers.attention.kv_cache import get_kv_scales +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, + SpeculativeHead, + FastLinear, +) + +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from .flash_qwen2_modeling import Qwen2MLP +from .flash_qwen3_modeling import Qwen3Attention +from transformers.activations import ACT2FN +from text_generation_server.layers.rotary import PositionRotaryEmbedding + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, prefix, weights, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = FastLinear.load( + config, f"{prefix}.q_proj", weights, bias=config.attention_bias + ) + + self.k_proj = FastLinear.load( + config, f"{prefix}.k_proj", weights, bias=config.attention_bias + ) + self.v_proj = FastLinear.load( + config, f"{prefix}.v_proj", weights, bias=config.attention_bias + ) + self.o_proj = FastLinear.load( + config, f"{prefix}.o_proj", weights, bias=config.attention_bias + ) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_dim, + base=config.rope_theta, + device=weights.device, + ) + + self.q_norm = FastRMSNorm.load( + prefix=f"{prefix}.q_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + self.k_norm = FastRMSNorm.load( + prefix=f"{prefix}.k_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_key_value_groups) + + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)) + key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + self.rotary_emb(query_states, key_states, cos, sin) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + kv_cache.store( + key=key_states, + value=value_states, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # sdpa + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.scaling, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query_states, + kv_cache, + self.kv_head_mapping, + self.scaling, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + 
+class Qwen3MoE(nn.Module): + def __init__(self, prefix, config, moe_layer_cls: Type[MoELayer], weights): + super().__init__() + + # gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.moe = moe_layer_cls( + n_expert_group=None, + n_experts=config.num_experts, + prefix=f"{prefix}.experts", + renormalize=True, + topk=config.num_experts_per_tok, + topk_group=None, + weights=weights, + ) + # gate_proj_name="w1", + # up_proj_name="w3", + # down_proj_name="w2", + + assert isinstance(self.moe, MoELayer) + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + router_logits = self.gate(x) + out = self.moe(x, gating_output=router_logits) + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, prefix, config, weights, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + intermediate_size + if intermediate_size is not None + else config.intermediate_size + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + self.experts = nn.ModuleList( + [ + Qwen3MoeMLP( + prefix=f"{prefix}.experts.{i}", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size, + ) + for i in range(self.num_experts) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + input_shape = hidden_states.shape + _, hidden_dim = hidden_states.shape + # hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! 
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (input_shape), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape(input_shape) + return final_hidden_states + + +class Qwen3MoeDecoderLayer(nn.Module): + def __init__(self, config, prefix, weights, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.num_key_value_heads // weights.process_group.size() > 0: + self.self_attn = Qwen3Attention( + config, + prefix=f"{prefix}.self_attn", + weights=weights, + layer_idx=layer_idx, + ) + else: + self.self_attn = Qwen3MoeAttention( + config, + prefix=f"{prefix}.self_attn", + weights=weights, + layer_idx=layer_idx, + ) + + moe_layer_cls = ( + SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer + ) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + # self.mlp = Qwen3MoeSparseMoeBlock(f"{prefix}.mlp", config, weights) + + else: + self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + + hidden_states, _ = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states, _ = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3MoeModel(nn.Module): + def __init__(self, config, prefix: str, weights): + super().__init__() + self.config = config + 
self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.layers = nn.ModuleList( + [ + Qwen3MoeDecoderLayer( + config=config, + prefix=f"{prefix}.layers.{layer_idx}", + weights=weights, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, + ) + + residual = None + for i, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states, _ = self.norm(hidden_states) + + # add hidden states from the last decoder layer + return hidden_states + + +class Qwen3MoeForCausalLM(nn.Module): + + def __init__(self, prefix: str, config, weights): + super().__init__() + self.model = Qwen3MoeModel(config=config, prefix="model", weights=weights) + self.vocab_size = config.vocab_size + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.{suffix}" if prefix else suffix, + weights=weights, + ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + inputs_embeds = self.embed_tokens(input_ids) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + + return logits diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 141e13a6..06616f85 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -18,9 +18,11 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, paged_attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) +import habana_frameworks.torch as htorch def load_row(config, prefix: str, weights, bias: bool): @@ -627,6 +629,10 @@ class 
FlashRWModel(FlashRWPreTrainedModel): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward @@ -634,6 +640,9 @@ class FlashRWModel(FlashRWPreTrainedModel): cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.h): hidden_states, residual = layer( hidden_states, @@ -646,6 +655,8 @@ class FlashRWModel(FlashRWPreTrainedModel): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index b68f4784..b6a0d32a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,6 +8,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -23,6 +24,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) +import habana_frameworks.torch as htorch def load_multi_mqa( @@ -436,12 +438,19 @@ class FlashSantacoderModel(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.wte(input_ids) + self.wpe(position_ids) if self.process_group.size() > 1: torch.distributed.all_reduce(hidden_states, group=self.process_group) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -452,6 +461,8 @@ class FlashSantacoderModel(nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.ln_f(hidden_states, residual) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 76f6f473..1a749595 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,6 +29,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -50,6 +51,7 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight +import habana_frameworks.torch as htorch class Starcoder2Config(PretrainedConfig): @@ -510,6 +512,10 @@ class Starcoder2Model(torch.nn.Module): adapter_data, hpu_attention_meta: 
Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -517,6 +523,9 @@ class Starcoder2Model(torch.nn.Module): cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, @@ -530,6 +539,8 @@ class Starcoder2Model(torch.nn.Module): adapter_data, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states, residual) @@ -578,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model( input_ids, position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index 02806ac9..41a45373 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -734,9 +734,107 @@ class Idefics2ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + assert pixel_values is not None + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. 
+ nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + """ + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.gt(patches_subgrid, 0) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), + ) + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -744,98 +842,9 @@ class Idefics2ForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values = all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = 
pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - """ - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - """ - # hpu does none support unfold - conv_kernel = torch.ones( - [1, 1, patch_size, patch_size], - dtype=pixel_values.dtype, - device=pixel_values.device, - ) - patches_subgrid = torch.nn.functional.conv2d( - pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), - conv_kernel, - stride=patch_size, - ).squeeze(1) - patch_attention_mask = torch.eq( - patches_subgrid, (patch_size * patch_size) - ) - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - attention_mask=patch_attention_mask.view(pixel_values.size(0), -1), - ) - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - # When we generate, we don't want to replace the potential image_token_id that we generated by images - # that simply don't exist - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py index 964526fc..6dd44c11 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -477,9 +477,107 @@ class Idefics3ForConditionalGeneration(nn.Module): inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) return inputs_embeds - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + batch_size, num_images, num_channels, height, width = pixel_values.shape + all_states = [] + all_pixel_values = pixel_values + all_pixel_mask = pixel_attention_mask + for i in range(batch_size): + pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values[i : i + 1] + pixel_values = pixel_values.view(num_images, 
*pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3) + ) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=( + pixel_values.size(0), + pixel_values.size(2), + pixel_values.size(3), + ), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = all_pixel_mask[i : i + 1] + pixel_attention_mask = pixel_attention_mask.view( + 1 * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds + ].contiguous() + + patch_size = self.config.vision_config.patch_size + + """ + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + """ + # hpu does none support unfold + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], + dtype=pixel_values.dtype, + device=pixel_values.device, + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), + conv_kernel, + stride=patch_size, + ).squeeze(1) + patch_attention_mask = torch.gt(patches_subgrid, 0) + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, + ) + + all_states.append(image_hidden_states) + image_hidden_states = torch.stack(all_states, dim=0) + + return image_hidden_states.view(-1, image_hidden_states.shape[-1]) + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.text_model.embed_tokens(input_ids) + + if vision_embeds is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self._merge_input_ids_with_image_features( + input_ids, inputs_embeds, vision_embeds + ) + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -487,99 +585,10 @@ class Idefics3ForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, - pixel_values: torch.FloatTensor = None, - pixel_attention_mask: Optional[torch.BoolTensor] = None, - # Unused here - image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - all_states = [] - all_pixel_values = pixel_values - all_pixel_mask = pixel_attention_mask - for i in range(batch_size): - pixel_values 
= all_pixel_values.to( - dtype=self.dtype - ) # fp16 compatibility - pixel_values = pixel_values[i : i + 1] - pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3) - ) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=( - pixel_values.size(0), - pixel_values.size(2), - pixel_values.size(3), - ), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = all_pixel_mask[i : i + 1] - pixel_attention_mask = pixel_attention_mask.view( - 1 * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[ - real_images_inds - ].contiguous() - - patch_size = self.config.vision_config.patch_size - """ - patches_subgrid = pixel_attention_mask.unfold( - dimension=1, size=patch_size, step=patch_size - ) - patches_subgrid = patches_subgrid.unfold( - dimension=2, size=patch_size, step=patch_size - ) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - """ - # hpu does none support unfold - conv_kernel = torch.ones( - [1, 1, patch_size, patch_size], - dtype=pixel_values.dtype, - device=pixel_values.device, - ) - patches_subgrid = torch.nn.functional.conv2d( - pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), - conv_kernel, - stride=patch_size, - ).squeeze(1) - patch_attention_mask = torch.eq( - patches_subgrid, (patch_size * patch_size) - ) - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ) - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, - ) - - all_states.append(image_hidden_states) - image_hidden_states = torch.stack(all_states, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_hidden_states - ) - hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py deleted file mode 100644 index 00ecdf95..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py +++ /dev/null @@ -1,467 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
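Both the Idefics2 and Idefics3 hunks above replace `Tensor.unfold` with a ones-kernel `conv2d` when building the patch attention mask, since unfold is not supported on HPU: the convolution counts the valid pixels in each patch, and the new `torch.gt(..., 0)` check keeps any patch that contains at least one valid pixel. A self-contained sketch of that computation with made-up values:

```python
import torch
import torch.nn.functional as F

patch_size = 2
pixel_attention_mask = torch.tensor(
    [[[1, 1, 1, 0],
      [1, 1, 1, 0],
      [1, 1, 0, 0],
      [1, 1, 0, 0]]], dtype=torch.bool
)  # [num_images, height, width]

# Ones kernel: each output element is the count of valid pixels in one patch.
conv_kernel = torch.ones([1, 1, patch_size, patch_size], dtype=torch.float32)
patches_subgrid = F.conv2d(
    pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
    conv_kernel,
    stride=patch_size,
).squeeze(1)                                # [num_images, H/ps, W/ps]
patch_attention_mask = torch.gt(patches_subgrid, 0)
print(patch_attention_mask)
# tensor([[[ True,  True],
#          [ True, False]]])
```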
-""" PyTorch Llava-NeXT model.""" - -from typing import List, Optional, Union - -import torch -import torch.utils.checkpoint -import numpy as np - -from loguru import logger -from transformers.models.llava_next.modeling_llava_next import ( - unpad_image, -) -from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration -from transformers.image_processing_utils import select_best_resolution - - -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. - - Args: - image_size (`tuple`): - The size of the input image in the format (width, height). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - tuple: The shape of the image patch grid in the format (width, height). - """ - if not isinstance(grid_pinpoints, list): - raise ValueError("grid_pinpoints should be a list of tuples or lists") - - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size - - -# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79 -def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): - """ - Calculate the number of patches after the preprocessing for images of any resolution. - - Args: - image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): - The size of the input image in the format (height, width). ? - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. - - Returns: - int: the number of patches - """ - if not isinstance(grid_pinpoints, list): - raise TypeError("grid_pinpoints should be a list of tuples or lists") - - # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate - if not isinstance(image_size, (list, tuple)): - if not isinstance(image_size, (torch.Tensor, np.ndarray)): - raise TypeError( - f"image_size invalid type {type(image_size)} with value {image_size}" - ) - image_size = image_size.tolist() - - best_resolution = select_best_resolution(image_size, grid_pinpoints) - height, width = best_resolution - num_patches = 0 - # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 - for i in range(0, height, patch_size): - for j in range(0, width, patch_size): - num_patches += 1 - # add the base patch - num_patches += 1 - return num_patches - - -class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): - - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - image_sizes: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, - ): - - if token_idx is not None: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - logits = outputs[0] - - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411 - def pack_image_features( - self, - image_features, - image_sizes, - vision_feature_select_strategy, - image_newline=None, - ): - """ - Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. - - Args: - image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) - List of image feature tensor, each contains all the visual feature of all patches. - image_sizes (`torch.Tensor` of shape `(num_images, 2)`) - Actual image size of each images (H, W). - vision_feature_select_strategy (`str`) - The feature selection strategy used to select the vision feature from the vision backbone. 
- image_newline (`torch.Tensor` of shape `(embed_dim)`) - New line embedding vector. - Returns: - image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) - feature_lens (`List[int]`) - token length of each image in image_features - """ - new_image_features = [] - feature_lens = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size - ) - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, - ) - - if ( - np.prod(image_feature.shape) - % (num_patch_height * num_patch_width * height * width) - != 0 - and vision_feature_select_strategy == "default" - ): - logger.warning_once( - "Image feature shape does not line up with the provided patch size. " - "You may be using the `default` vision_feature_select_strategy with a" - " visual encoder that does not have CLS." - ) - - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - if image_newline is not None: - image_feature = torch.cat( - ( - image_feature, - image_newline[:, None, None] - .expand(*image_feature.shape[:-1], 1) - .to(image_feature.device, image_feature.dtype), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: - image_feature = image_feature[0] - if image_newline is not None: - image_feature = torch.cat( - (image_feature, image_newline[None].to(image_feature)), dim=0 - ) - new_image_features.append(image_feature) - feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor( - feature_lens, dtype=torch.long, device=image_features.device - ) - return image_features, feature_lens - - # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479 - def get_image_features( - self, - pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, - ): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) - The tensors corresponding to the input images. - image_sizes (`torch.Tensor` of shape `(num_images, 2)`) - Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. - vision_feature_select_strategy (`str`): - The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"` - Returns: - image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches - and are of shape `(num_patches, image_length, embed_dim)`). 
- """ - # ! infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - if pixel_values.dim() == 5: - # stacked if input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] - for pix_val, num_patch in zip(pixel_values, image_num_patches) - ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of (num_patches, num_channels, height, width) - raise ValueError( - f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions" - ) - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - # If we have one vision feature layer, return the corresponding hidden states, - # otherwise, select the hidden states of each feature layer and concatenate them - if isinstance(vision_feature_layer, int): - selected_image_feature = image_features.hidden_states[vision_feature_layer] - else: - hs_pool = [ - image_features.hidden_states[layer_idx] - for layer_idx in vision_feature_layer - ] - selected_image_feature = torch.cat(hs_pool, dim=-1) - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - image_features = torch.split(image_features, image_num_patches, dim=0) - return image_features - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - pixel_values=None, - image_sizes=None, - attention_mask=None, - **kwargs, - ): - """ - Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 - The only differences are: - - add new args token_idx - - add the process of merging images into inputs_embeds - """ - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - - position_ids = kwargs.get("position_ids", None) - labels = kwargs.get("labels", None) - if ( - past_key_values is None - and pixel_values is not None - and input_ids.shape[1] != 1 - ): - vision_feature_select_strategy = kwargs.get( - "vision_feature_select_strategy", None - ) - vision_feature_layer = kwargs.get("vision_feature_layer", None) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_feature_select_strategy - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. 
Merge text and images - image_features = self.get_image_features( - pixel_values, - image_sizes, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=vision_feature_select_strategy, - image_newline=self.image_newline, - ) - - special_image_mask = ( - input_ids == self.config.image_token_index - ).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - if inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_features = image_features.to( - inputs_embeds.device, inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds.masked_scatter( - special_image_mask, image_features - ) - - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None: - seq_len = input_ids.shape[1] - pad_len = seq_len - token_idx.item() - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) - # Get the target length - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = extended_attention_mask - attention_mask[:, -pad_len:] = 0 - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = ( - torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - 
"token_idx": token_idx, - "labels": labels, - "use_flash_attention": use_flash_attention, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py deleted file mode 100644 index 6ba0ffff..00000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py +++ /dev/null @@ -1,292 +0,0 @@ -# coding=utf-8 -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Mllama model.""" - -from typing import Optional, Tuple, List, Union - -import torch -import torch.utils.checkpoint - -from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration -from optimum.habana.transformers.models.mllama.modeling_mllama import ( - _prepare_cross_attention_mask, -) -from transformers.modeling_outputs import CausalLMOutputWithPast - - -class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration): - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = True, - flash_attention_recompute: Optional[bool] = True, - **kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - """ - Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 - The only differences are: - - add token_idx input - - add use_flash_attention and flash_attention_recompute - """ - full_text_row_masked_out_mask = kwargs.get( - "full_text_row_masked_out_mask", None - ) - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - 
use_cache=use_cache, - inputs_embeds=inputs_embeds, - labels=labels, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - token_idx=token_idx, - use_flash_attention=use_flash_attention, - flash_attention_recompute=flash_attention_recompute, - ) - - logits = outputs[0] - if not return_dict: - output = (logits,) + outputs[1:] - return output - - return outputs - - def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - num_logits_to_keep=None, - **kwargs, - ): - """ - Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 - The only differences are: - - add token_idx handling - - add bucket_internal handling - - add use_flash_attention and flash_attention_recompute - """ - - token_idx = kwargs.get("token_idx", None) - if token_idx is None: - return super().prepare_inputs_for_generation( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - cross_attention_mask=cross_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - else: - use_flash_attention = kwargs.get("use_flash_attention", True) - flash_attention_recompute = kwargs.get("flash_attention_recompute", True) - position_ids = kwargs.get("position_ids", None) - output_attentions = kwargs.get("output_attentions", None) - output_hidden_states = kwargs.get("output_hidden_states", None) - return_dict = kwargs.get("return_dict", None) - labels = kwargs.get("labels", None) - cross_attention_states = kwargs.get("cross_attention_states", None) - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - bucket_internal = kwargs.get("bucket_internal", None) - - if past_key_values is not None: - if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - elif inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif ( - input_ids.shape[1] != cache_position.shape[0] - ): # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - elif bucket_internal and token_idx is not None: - # for the 1st token we can slice the inputs till token idx for the fwd pass. - input_ids = input_ids[:, :token_idx] - attention_mask = attention_mask[:, :token_idx] - if cross_attention_mask is not None: - cross_attention_mask = cross_attention_mask[:, :token_idx, ...] 
- - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.index_select( - position_ids, 1, token_idx - 1 - ) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone( - memory_format=torch.contiguous_format - ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - use_flash_attention=use_flash_attention, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - token_idx=token_idx, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None: - if cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - elif past_key_values is not None: - if token_idx is not None: - cross_attention_mask = torch.index_select( - cross_attention_mask, -2, token_idx - 1 - ) - full_text_row_masked_out_mask = torch.index_select( - full_text_row_masked_out_mask, -2, token_idx - 1 - ) - else: - cross_attention_mask = cross_attention_mask[:, :, -1:] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, -1: - ] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = { - "input_ids": input_ids.clone(memory_format=torch.contiguous_format), - "inputs_embeds": None, - } - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - # keep cache_position implementation as None for HPU - cache_position = None - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "labels": labels, - "return_dict": kwargs.get("return_dict"), - "full_text_row_masked_out_mask": full_text_row_masked_out_mask, - "use_flash_attention": use_flash_attention, - "cross_attention_mask": cross_attention_mask, - "cross_attention_states": cross_attention_states, - "output_attentions": output_attentions, - "flash_attention_recompute": flash_attention_recompute, - } - ) - - return model_inputs diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 441b0016..ac1578e9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -45,11 +45,17 @@ from text_generation_server.layers.attention import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, ) +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) +import habana_frameworks.torch as htorch # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py from typing import Union from transformers.feature_extraction_utils import BatchFeature -from transformers.image_utils import ImageInput, VideoInput +from transformers.image_utils import ImageInput +from transformers.video_utils import VideoInput from transformers.processing_utils import ( ProcessingKwargs, ProcessorMixin, @@ -375,28 +381,6 @@ class Qwen2_5_VLConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - class Qwen2_5VLAttention(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() @@ -426,7 +410,8 @@ class Qwen2_5VLAttention(nn.Module): self, hidden_state: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: # apply the qkv linear layer to the hidden state @@ -444,29 +429,37 @@ class Qwen2_5VLAttention(nn.Module): query = query.view(*_shape) key = key.view(*_shape) value = value.view(*_shape) - # apply rotary positional embeddings - query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( - 0 - ) - key = 
apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + rotary_dim = cos.shape[-1] + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape)) - # calc maximum sequence length for any batch - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - causal = False + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape)) # execute sdpa - query = query.unsqueeze(0).transpose(1, 2) - key = key.unsqueeze(0).transpose(1, 2) - value = value.unsqueeze(0).transpose(1, 2) + causal = False + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attention_mask = torch.zeros( + [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i] + ] = True attn_output = fsdpa_op( query, key, value, - attn_mask=None, + attn_mask=attention_mask, dropout_p=0.0, is_causal=causal, scale=None, @@ -474,7 +467,7 @@ class Qwen2_5VLAttention(nn.Module): recompute_mode=None, valid_sequence_lengths=None, ) - attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = attn_output.transpose(0, 1) # reshape output to original dimensions attn_output = attn_output.reshape(hidden_state.shape[0], -1) @@ -533,11 +526,9 @@ class Qwen2_5VLVisionBlock(nn.Module): weights=weights, ) - def forward( - self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen - ) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor: norm1_out, _ = self.norm1(hidden_states) - attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen) hidden_states = hidden_states + attn_out norm2_out, _ = self.norm2(hidden_states) mlp_out = self.mlp(norm2_out) @@ -608,7 +599,7 @@ class Qwen2_5VisionModel(nn.Module): config=config, weights=weights, ) - # import ipdb; ipdb.set_trace() + self.temporal_patch_size = config.temporal_patch_size self.spatial_patch_size = config.spatial_patch_size self.in_channels = config.in_channels @@ -736,6 +727,10 @@ class Qwen2_5VisionModel(nn.Module): ) rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = torch.cat((cos, cos), dim=-1).unsqueeze(1) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(1) cu_window_seqlens = torch.tensor( cu_window_seqlens, @@ -754,6 +749,9 @@ class Qwen2_5VisionModel(nn.Module): max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) # iterately apply the blocks to the hidden states + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for layer_num, block in enumerate(self.blocks): # NOTE: qwen2_5_vl.py has a concept of full attention blocks # that are applied at specific layers. 
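
The FusedSDPA calls introduced in the qwen2_5_vl / qwen2_vl hunks above drop the varlen-style attention path, so the per-image sequence boundaries carried by `cu_seqlens` are re-encoded as a block-diagonal boolean `attn_mask`. A minimal, self-contained sketch of that construction in plain PyTorch (illustrative helper name, not the exact TGI code):

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    # cu_seqlens holds cumulative segment lengths, e.g. [0, 2, 5]:
    # tokens 0-1 belong to the first image, tokens 2-4 to the second.
    mask = torch.zeros(1, max_seqlen, max_seqlen, dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        # each token may only attend to tokens of its own segment
        mask[:, start:end, start:end] = True
    return mask

cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32)
print(block_diagonal_mask(cu_seqlens, 5))
# prints two blocks of True on the diagonal (2x2 and 3x3), False elsewhere

Passing a dense mask of shape (1, max_seqlen, max_seqlen) keeps the attention call fixed-shape per bucket, which fits the graph-friendly style used elsewhere in these hunks (lazy mode plus mark_step between blocks), at the cost of materializing the full mask.
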
@@ -762,9 +760,9 @@ class Qwen2_5VisionModel(nn.Module): else: cu_seqlens_now = cu_window_seqlens - hidden_states = block( - hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen - ) + hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen) + if lazy_mode: + htorch.core.mark_step() # apply the final patch merger to the hidden states hidden_states = self.merger(hidden_states) @@ -886,9 +884,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): full_llm_pos_ids_list = [ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist ] - # import ipdb - - # ipdb.set_trace() max_s = full_llm_pos_ids_list[-1].max() + 1 final_text_len = input_ids_len - vision_ends[-1] if final_text_len > 0: @@ -900,9 +895,33 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): ) return position_ids - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None: + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = vision_embeds + + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -910,26 +929,10 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor], - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - # Unused in this model - video_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - mask = torch.where(input_ids == self.image_token_id) - inputs_embeds[mask] = image_embeds hidden_states = self.text_model( inputs_embeds=inputs_embeds, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 47ae2ac9..96acef31 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -44,28 +44,11 @@ from text_generation_server.layers.attention import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return 
torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) +import habana_frameworks.torch as htorch class Qwen2VLAttention(nn.Module): @@ -96,7 +79,8 @@ class Qwen2VLAttention(nn.Module): self, hidden_state: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: # apply the qkv linear layer to the hidden state @@ -116,27 +100,36 @@ class Qwen2VLAttention(nn.Module): value = value.view(*_shape) # apply rotary positional embeddings - query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( - 0 - ) - key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + rotary_dim = cos.shape[-1] + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape)) - # calc maximum sequence length for any batch - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - causal = False + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape)) # execute sdpa - query = query.unsqueeze(0).transpose(1, 2) - key = key.unsqueeze(0).transpose(1, 2) - value = value.unsqueeze(0).transpose(1, 2) + causal = False + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) fsdpa_op = ModuleFusedSDPA(FusedSDPA) + attention_mask = torch.zeros( + [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i] + ] = True attn_output = fsdpa_op( query, key, value, - attn_mask=None, + attn_mask=attention_mask, dropout_p=0.0, is_causal=causal, scale=None, @@ -144,7 +137,7 @@ class Qwen2VLAttention(nn.Module): recompute_mode=None, valid_sequence_lengths=None, ) - attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous() + attn_output = attn_output.transpose(0, 1) # reshape output to original dimensions attn_output = attn_output.reshape(hidden_state.shape[0], -1) attn_output = self.proj(attn_output) @@ -193,11 +186,9 @@ class Qwen2VLVisionBlock(nn.Module): weights=weights, ) - def forward( - self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen - ) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor: norm1_out, residual = self.norm1(hidden_states) - attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen) hidden_states = attn_out + residual norm2_out, residual = self.norm2(hidden_states) hidden_states = hidden_states + self.mlp(norm2_out) @@ -330,6 +321,11 @@ class Qwen2VisionModel(nn.Module): 
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = torch.cat((cos, cos), dim=-1).unsqueeze(1) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(1) + # create a cu_seqlens tensor to be used in the attention mask cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -337,8 +333,13 @@ class Qwen2VisionModel(nn.Module): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) # iterately apply the blocks to the hidden states + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for block in self.blocks: - hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen) + if lazy_mode: + htorch.core.mark_step() # apply the final patch merger to the hidden states hidden_states = self.merger(hidden_states) @@ -474,9 +475,33 @@ class Qwen2VLForConditionalGeneration(nn.Module): ) return position_ids - def forward( + def get_vision_embeds( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0) + return image_embeds + + def get_inputs_embeds( self, input_ids: torch.Tensor, + vision_embeds: torch.Tensor = None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if vision_embeds is not None: + mask = torch.where(input_ids == self.image_token_id) + inputs_embeds[mask] = vision_embeds + + return inputs_embeds + + def forward( + self, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -484,26 +509,10 @@ class Qwen2VLForConditionalGeneration(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor], - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, adapter_data: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - inputs_embeds = self.embed_tokens(input_ids) - - # apply the visual model to the pixel values if they are provided - if pixel_values is not None and len(pixel_values) > 0: - if pixel_values is not None: - image_embeds = self.visual( - pixel_values, grid_thw=image_grid_thw - ).squeeze(0) - mask = torch.where(input_ids == self.image_token_id) - inputs_embeds[mask] = image_embeds - hidden_states = self.text_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index ad585172..9883a73f 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -53,6 +53,7 @@ from text_generation_server.models.globals import ( ) from 
text_generation_server.layers.attention import ( KVCache, + KVCompressCache, Seqlen, HPUPagedAttentionMetadata, trim_attn_metadata, @@ -68,11 +69,14 @@ from text_generation_server.utils.import_utils import ( synchronize, get_free_memory, ) - +from text_generation_server.utils.prefill_chunking import ( + get_max_prefill_tokens, +) import vllm_hpu_extension.environment as environment import habana_frameworks.torch as htorch import itertools -from vllm_hpu_extension.bucketing import HPUBucketingContext +from vllm_hpu_extension.bucketing.common import get_bucketing_context +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) @@ -149,19 +153,14 @@ def prepare_for_decode( block_list_device = _async_h2d_tensor_copy(block_list) block_groups_device = _async_h2d_tensor_copy(block_groups) block_usage_device = _async_h2d_tensor_copy(block_usage) - block_mapping = torch.nn.functional.one_hot( - block_groups_device, num_classes=batch_size - ) - mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage.unsqueeze(-1) - attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + return trim_attn_metadata( HPUPagedAttentionMetadata( block_list=block_list_device, block_groups=block_groups_device, block_usage=block_usage_device, - block_mapping=block_mapping.to(dtype), - attn_bias=attn_bias, + block_mapping=None, + attn_bias=None, ) ) @@ -424,7 +423,7 @@ class FlashCausalLMBatch(Batch): for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids - # Create tensors on device + # put on cpu temporarily, move to hpu in prepare_for_prefill all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) @@ -628,21 +627,25 @@ class FlashCausalLMBatch(Batch): # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] - adapter_indices = self.adapter_meta.adapter_indices[indices] input_lengths_tensor = self.input_lengths_tensor[indices] cache_lengths_tensor = self.cache_lengths_tensor[indices] # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) - - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) - adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ) + if self.adapter_meta is not None: + adapter_indices = self.adapter_meta.adapter_indices[indices] + adapter_segments, adapter_segment_indices = find_segments( + adapter_indices + ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) + else: + adapter_meta = None htorch.core.mark_step() return type(self)( batch_id=self.batch_id, @@ -691,7 +694,9 @@ class FlashCausalLMBatch(Batch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + def concatenate( + cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0 + ) -> "FlashCausalLMBatch": # Batch attributes requests = [] requests_idx_mapping = {} @@ -704,6 +709,7 @@ class 
FlashCausalLMBatch(Batch): max_length = 0 max_input_length = 0 max_current_length = 0 + ADAPTER_TO_INDEX = get_adapter_to_index() for b in batches: total_batch_size += len(b) max_blocks = max(max_blocks, b.max_blocks) @@ -739,7 +745,10 @@ class FlashCausalLMBatch(Batch): adapter_meta = None adapter_segment_builder = None else: - input_ids = batches[0].input_ids.new_empty(total_batch_size) + if padded_total_bs == batches[0].input_ids.shape[0]: + input_ids = batches[0].input_ids + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) if ( batches[0].position_ids is not None and batches[0].position_ids.dim() == 2 @@ -757,14 +766,15 @@ class FlashCausalLMBatch(Batch): cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty( total_batch_size ) - total_indices_size = sum( - b.adapter_meta.adapter_indices.shape[0] for b in batches - ) - adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( - total_indices_size - ) - adapter_segment_builder = SegmentConcatBuilder() - adapter_set = set() + if ADAPTER_TO_INDEX: + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_segment_builder = SegmentConcatBuilder() + adapter_set = set() prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty( total_batch_size @@ -772,9 +782,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( - (total_batch_size, max_length) - ) + all_input_ids_tensor = batches[0].all_input_ids_tensor top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) @@ -815,13 +823,14 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + valid_bsize - index = torch.tensor( - list(range(start_index, end_index)), device=batch.input_ids.device - ) + index = torch.tensor(list(range(start_index, end_index)), device="cpu") top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) - all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:valid_bsize, :max_length] + if i > 0: + all_input_ids_tensor.index_copy_( + 0, + index.to(batch.all_input_ids_tensor.device), + batch.all_input_ids_tensor[:valid_bsize, :], + ) block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] @@ -841,7 +850,10 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids.index_copy_(0, index, batch.input_ids[:valid_bsize]) + if padded_total_bs != batches[0].input_ids.shape[0] or i > 0: + input_ids.index_copy_( + 0, index.to(input_ids.device), batch.input_ids[:valid_bsize] + ) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) slot_indices.index_copy_( 0, index, batch.slot_indices + cumulative_slots @@ -852,20 +864,21 @@ class FlashCausalLMBatch(Batch): cache_lengths_tensor.index_copy_( 0, index, batch.cache_lengths_tensor[:valid_bsize] ) - adapter_start_index = cumulative_adapter_indices_size - adapter_end_index = ( - cumulative_adapter_indices_size - + batch.adapter_meta.adapter_indices.shape[0] - ) - adapter_indices[adapter_start_index:adapter_end_index] = ( - batch.adapter_meta.adapter_indices - ) - cumulative_adapter_indices_size = adapter_end_index - adapter_set.update(batch.adapter_meta.adapter_set) - adapter_segment_builder.concat( - 
batch.adapter_meta.adapter_segments, - batch.adapter_meta.segment_indices, - ) + if ADAPTER_TO_INDEX: + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() @@ -908,7 +921,7 @@ class FlashCausalLMBatch(Batch): else: speculative_ids = None - if adapter_segment_builder is not None: + if ADAPTER_TO_INDEX and adapter_segment_builder is not None: adapter_segments, adapter_segment_indices = adapter_segment_builder.build() adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, @@ -955,7 +968,7 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, - adapter_meta=adapter_meta, + adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None, hpu_attn_meta=None, next_token_logits=None, speculative_logits=None, @@ -974,7 +987,6 @@ class FlashCausalLMBatch(Batch): else: padded_bs = self.input_ids.shape[0] slots = self.slots[self.slot_indices] - extra_pad = padded_bs - self.input_ids.shape[0] self.hpu_attn_meta = prepare_for_decode( dtype, @@ -985,17 +997,29 @@ class FlashCausalLMBatch(Batch): padded_bs, bucketing_ctx, ) - self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0) - self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1) + self.input_ids = F.pad( + self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0 + ) + + if self.position_ids.dim() == 2: + # Qwen VL case + self.position_ids = F.pad( + self.position_ids, + (0, 0, 0, padded_bs - self.position_ids.shape[0]), + value=1, + ) + else: + self.position_ids = F.pad( + self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 + ) self.input_lengths_tensor = F.pad( - self.input_lengths_tensor, (0, extra_pad), value=0 + self.input_lengths_tensor, + (0, padded_bs - self.input_lengths_tensor.shape[0]), + value=0, ) self.cache_lengths_tensor = F.pad( - self.cache_lengths_tensor, (0, extra_pad), value=0 - ) - self.all_input_ids_tensor = F.pad( - self.all_input_ids_tensor, - (0, 0, 0, extra_pad), + self.cache_lengths_tensor, + (0, padded_bs - self.cache_lengths_tensor.shape[0]), value=0, ) next_token_chooser_parameters = [] @@ -1015,7 +1039,9 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states, ) - def prepare_for_prefill(self, max_padded_input_len, max_padded_bs): + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything @@ -1031,6 +1057,7 @@ class FlashCausalLMBatch(Batch): # need extra pad to match warmup seq extra_pad = max_padded_input_len - self.max_input_length extra_pad_bs = max_padded_bs - len(self) + device = "hpu" if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1041,15 +1068,26 @@ class FlashCausalLMBatch(Batch): input_ids.append(input_id) input_ids_padded_length.append(padded) input_ids = np.concatenate(input_ids, dtype=np.int64) - self.input_ids = 
torch.tensor(input_ids, dtype=torch.int64) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) elif isinstance(self.input_ids, list): input_ids = self.input_ids[0] input_ids_padded_length.append(extra_pad) input_ids = [0] * extra_pad + input_ids - self.input_ids = torch.tensor(input_ids, dtype=torch.int64) + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) else: - self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0) - input_ids_padded_length.extend([extra_pad] * len(self)) + input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self)) + src_pos = 0 + for i in range(len(self)): + end_pos = (i + 1) * max_padded_input_len + start_pos = end_pos - self.input_lengths[i] + input_ids[start_pos:end_pos] = self.input_ids[ + src_pos : src_pos + self.input_lengths[i] + ] + input_ids_padded_length.append( + max_padded_input_len - self.input_lengths[i] + ) + src_pos += self.input_lengths[i] + self.input_ids = input_ids self.input_ids = F.pad( self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0 @@ -1239,7 +1277,9 @@ class FlashCausalLMBatch(Batch): self.slot_indices = slot_indices self.prefill_cu_outlens = prefill_cu_outlens - self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) + self.prefill_cache_indices = torch.zeros_like( + self.input_ids, dtype=torch.bool, device="cpu" + ) self.prefill_cache_indices[prefill_cache_indices] = True if all_prefill_logprobs: @@ -1272,12 +1312,17 @@ class FlashCausalLMBatch(Batch): self.prefill_next_token_indices = ( self.prefill_next_token_indices + input_ids_padded_length_tensor ) - - self.all_input_ids_tensor = F.pad( - self.all_input_ids_tensor, - (0, 0, 0, extra_pad_bs), - value=0, + all_input_ids_tensor = torch.zeros( + (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])), + dtype=torch.int64, + device="hpu", ) + for i in range(len(self)): + all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = ( + self.all_input_ids_tensor[i] + ) + self.all_input_ids_tensor = all_input_ids_tensor + next_token_chooser_parameters = [] next_token_chooser_parameters.extend([r.parameters for r in self.requests]) pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs) @@ -1295,21 +1340,24 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states, ) - if adapter_set: - adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - else: - adapter_indices = torch.zeros_like(self.input_ids) - adapter_segments = [0, len(adapter_indices)] - adapter_segment_indices = [len(adapter_indices) - 1] + if ADAPTER_TO_INDEX: + if adapter_set: + adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64) + adapter_segments, adapter_segment_indices = find_segments( + adapter_indices + ) + else: + adapter_indices = torch.zeros_like(self.input_ids) + adapter_segments = [0, len(adapter_indices)] + adapter_segment_indices = [len(adapter_indices) - 1] - adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) - self.adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ) + adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + 
segment_indices=adapter_segment_indices, + ) def __len__(self): return len(self.requests) @@ -1352,6 +1400,8 @@ class FlashCausalLM(Model): ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() + if world_size > 1: + self.process_group_cpu = torch.distributed.new_group(backend="gloo") device = torch.device("hpu") dtype = torch.bfloat16 if dtype is None else dtype @@ -1419,7 +1469,7 @@ class FlashCausalLM(Model): raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = ( num_kv_heads // self.process_group.size() - if num_kv_heads > 1 + if num_kv_heads // self.process_group.size() > 0 else num_kv_heads ) assert self.num_kv_heads > 0 @@ -1427,7 +1477,7 @@ class FlashCausalLM(Model): if head_size is None: # Some models use GQA and different sizes for o_proj # and q_proj, that allows for that. - if hasattr(config, "head_dim"): + if getattr(config, "head_dim", None) is not None: self.head_size = config.head_dim else: self.head_size = config.hidden_size // config.num_attention_heads @@ -1438,15 +1488,20 @@ class FlashCausalLM(Model): self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype self.bucketing_ctx = None + self.max_total_tokens = None + self.max_input_tokens = None + htorch.core.hpu_set_env() if htorch.utils.internal.is_lazy(): htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) environment.set_model_config(self.config) self.use_contiguous_pa = ( os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true" ) - self.limit_hpu_graphs = ( - os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true" + self.limit_hpu_graph = ( + os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true" ) + self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true" + self.max_seq_len_to_capture = 8192 super().__init__( model_id=model_id, model=model, @@ -1478,16 +1533,27 @@ class FlashCausalLM(Model): ): self.kv_cache = [] empty_cache() - self.kv_cache = [ - KVCache( - num_blocks=num_blocks, - num_heads=num_heads, - head_size=head_size, - dtype=dtype, - device=device, - ) - for _ in range(num_layers) - ] + if self.config.model_type == "deepseek_v3": + self.kv_cache = [ + KVCompressCache( + num_blocks=num_blocks, + head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + KVCache( + num_blocks=num_blocks, + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] def warmup( self, @@ -1495,16 +1561,48 @@ class FlashCausalLM(Model): max_input_tokens: Optional[int], max_total_tokens: Optional[int], ): + if os.environ.get("MAX_BATCH_SIZE") is None: + raise RuntimeError( + "MAX_BATCH_SIZE is not set, it should be set in the launcher " + "using `--max-batch-size xxx`" + ) # The warmup batch is the biggest batch we could ever receive self.kv_cache = [] empty_cache() - + self.graphed_buckets = set() # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() - cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size - total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + if self.config.model_type == "deepseek_v3": + cache_block_size = BLOCK_SIZE * ( + self.config.kv_lora_rank + self.config.qk_rope_head_dim + ) + 
else: + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + cache_block_size = cache_block_size * 2 + total_cache_size = self.num_layers * cache_block_size * dtype_size + free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM) + self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION)) + graph_reserved_mem = ( + float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1")) + if htorch.utils.internal.is_lazy() + else 0 + ) + mem_used_from_graph = int( + (free_memory - self.mem_reserved) * graph_reserved_mem + ) + log_master( + logger.info, + f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}", + ) + if max_total_tokens is None: + max_total_tokens = sum(batch.input_lengths) + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + + self.max_total_tokens = max_total_tokens + self.max_input_tokens = max_input_tokens try: self.init_kv_cache( batch.num_blocks, @@ -1519,15 +1617,6 @@ class FlashCausalLM(Model): num_tokens = batch.to_pb().current_tokens synchronize(self.device) - free_memory = get_free_memory( - self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM - ) - real_free_memory = get_free_memory(self.device, MEMORY_FRACTION) - log_master( - logger.debug, - f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB", - ) - _, _batch, _ = self.generate_token([batch]) except Exception: raise RuntimeError( @@ -1536,8 +1625,9 @@ class FlashCausalLM(Model): ) synchronize(self.device) - free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM) - kv_memory = free_memory + free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM) + + kv_memory = free_memory - self.mem_reserved - mem_used_from_graph num_blocks = ( # Leave 5% for some wiggle room int(kv_memory // total_cache_size) @@ -1546,15 +1636,9 @@ class FlashCausalLM(Model): ) log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") - if max_total_tokens is None: - max_total_tokens = sum(batch.input_lengths) - - if max_input_tokens is None: - max_input_tokens = max_total_tokens - 1 self.kv_cache = [] empty_cache() - self.init_kv_cache( num_blocks, self.num_layers, @@ -1563,56 +1647,177 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) - - max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128)) - if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None: - os.environ["VLLM_PROMPT_SEQ_BUCKET_MAX"] = str(max_input_tokens) - if os.getenv("VLLM_DECODE_BLOCK_BUCKET_MAX") is None: - max_total_blocks = ( - math.ceil(max_total_tokens / BLOCK_SIZE) * max_num_seqs + 1 - ) - os.environ["VLLM_DECODE_BLOCK_BUCKET_MAX"] = str(max_total_blocks) + self.max_batch_prefill_tokens = get_max_prefill_tokens() + max_num_seqs = int(os.getenv("MAX_BATCH_SIZE")) + HPUBucketingContext = get_bucketing_context() + # need to warmup one more step since block is allocated from 1 + block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE) + max_total_tokens_aligned = math.ceil( + max_total_tokens / BLOCK_SIZE + ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs) + model_max_length = self.tokenizer.model_max_length + max_position_embeddings = getattr( + self.config, "max_position_embeddings", model_max_length + ) self.bucketing_ctx = HPUBucketingContext( max_num_seqs, - os.getenv("PREFILL_MAX_BS", 64), # self.max_num_prefill_seqs, #TODO + max_num_seqs, # self.max_num_prefill_seqs, #TODO BLOCK_SIZE, - 
num_blocks * BLOCK_SIZE, + max_num_seqs * max_total_tokens_aligned, False, + min(model_max_length, max_position_embeddings), + max_input_tokens, + max_total_tokens_aligned, ) - self.bucketing_ctx.num_hpu_blocks = num_blocks - if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true": - logger.info("skip warmup hpu graph, not recommmended") + max_blocks = max( + BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE + ) + self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks) + synchronize(self.device) + if self.skip_warmup: + self.bucketing_ctx.generate_prompt_buckets() + self.bucketing_ctx.generate_decode_buckets( + self.bucketing_ctx.num_hpu_blocks + ) + log_master( + logger.info, "skip warmup hpu graph, not recommmended, may cause OOM" + ) del _batch, batch return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - self.warmup_hpu_graph(batch) del _batch, batch return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens + def log_warmup(self, prefilling, i, max_i, batch_size, seq_len): + free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory()) + phase = "Prompt" if prefilling else "Decode" + dim = "seq_len" if prefilling else "num_blocks" + graphed_bucket = (batch_size, seq_len, prefilling) + bypass = graphed_bucket not in self.graphed_buckets + msg = ( + f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"bypass:{bypass} " + f"free_mem:{free_mem}" + ) + log_master(logger.info, msg) + + def use_graphs(self, prefill, seq_len, batch_size): + if self.limit_hpu_graph and prefill: + return False + + if self.skip_warmup: + return True + + return (batch_size, seq_len, prefill) in self.graphed_buckets + + def align_workers(self, value, op): + if self.world_size <= 1: + return value + value_t = torch.tensor(value, device="cpu") + torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu) + return value_t.item() + def warmup_hpu_graph(self, batch): + prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3")) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_reserved + graph_free_mem = self.align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) + prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem + decode_available_memory = graph_free_mem - prompt_available_memory + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})" + ) + log_master(logger.info, msg) + start_time = time.time() + warmup_shape_count = 0 warmup_times = 3 self.bucketing_ctx.generate_prompt_buckets() - for i, (batch_size, seq_len) in enumerate( - reversed(self.bucketing_ctx.prompt_buckets) - ): - log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") - for index in range(warmup_times): - self.warmup_prefill(seq_len, batch_size, batch) + + def ordering_function_min_tokens(b): + return (b[0] * b[1], b[1], b[0]) + + buckets = list( + sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens) + ) + total_batch_seq = 0.001 + total_mem = 0 + available_mem = prompt_available_memory + for i, (batch_size, seq_len) in enumerate(buckets): + if batch_size * seq_len > self.max_batch_prefill_tokens: + continue + # Graph memory usage is proportional to 
seq dimension in a batch + batch_seq = batch_size * seq_len + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, seq_len, True) + if not ( + mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture + ): + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) + warmup_shape_count += 1 + self.log_warmup(True, i, len(buckets), batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX + ) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + def ordering_function_max_bs(b): + return (-b[0], b[1]) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) - for i, (batch_size, block_num) in enumerate( - reversed(self.bucketing_ctx.decode_buckets) - ): + buckets = list( + sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) + ) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + total_batch_seq = 0.001 + total_mem = 0 + available_mem = free_mem - self.mem_reserved + for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue - log_master( - logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, block_num, False) + if not mem_estimate >= available_mem: + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) + warmup_shape_count += 1 + self.log_warmup(False, i, len(buckets), batch_size, block_num) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX ) - for index in range(warmup_times): - self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + log_master( + logger.info, + f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", + ) def warmup_prefill( self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch @@ -1643,7 +1848,9 @@ class FlashCausalLM(Model): lm_head_indices = input_lengths - 1 kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + True, prompt_len, batch_size + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1696,7 +1903,9 @@ class FlashCausalLM(Model): slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = False + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + False, hpu_attention_meta.block_list.shape[0], batch_size + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), @@ -1779,11 +1988,11 @@ class FlashCausalLM(Model): # This makes sure the max_s for the decode pass is correct. max_s = min(self.max_past(), max_s) if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad else: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[: slots.shape[0]] = slots slots = slots_pad seqlen = Seqlen( @@ -1792,12 +2001,18 @@ class FlashCausalLM(Model): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = ( - batch.prefilling if self.limit_hpu_graphs else False + batch_size = input_lengths.shape[0] + prompt_len = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] + ) + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, prompt_len, batch_size ) logits, speculative_logits = self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + input_ids=input_ids, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, @@ -1836,9 +2051,9 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - _async_h2d_tensor_copy( - batch.all_input_ids_tensor[:, : batch.max_current_length] - ), + batch.all_input_ids_tensor[ + : batch.next_token_logits.shape[0], : batch.max_current_length + ], batch.next_token_logits, speculate, batch.speculative_ids, @@ -1852,15 +2067,29 @@ class FlashCausalLM(Model): accepted_ids, ) if batch.valid_indices is not None: - next_input_ids = next_input_ids.cpu() - next_token_logprobs = next_token_logprobs.cpu() - accepted_ids = accepted_ids.cpu() - batch.all_input_ids_tensor = batch.all_input_ids_tensor[ - batch.valid_indices - ] - next_input_ids = next_input_ids[batch.valid_indices] - next_token_logprobs = next_token_logprobs[batch.valid_indices] - accepted_ids = accepted_ids[batch.valid_indices] + # TODO speculative decoding handling missing + index = torch.arange( + 0, + len(batch.valid_indices), + device=batch.all_input_ids_tensor.device, + ) + batch.all_input_ids_tensor.index_copy_( + 0, index, batch.all_input_ids_tensor[batch.valid_indices] + ) + padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size( + len(batch.valid_indices) + ) + next_input_ids.index_copy_( + 0, index, next_input_ids[batch.valid_indices] + ) + next_input_ids = next_input_ids[:padded_total_bs] + + next_token_logprobs.index_copy_( + 0, index, next_token_logprobs[batch.valid_indices] + ) + accepted_ids.index_copy_( + 0, index, accepted_ids[batch.valid_indices] + ) if speculative_ids is not None: speculative_ids = speculative_ids[batch.valid_indices] batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[ @@ -1894,16 +2123,16 @@ class FlashCausalLM(Model): batch.position_ids = batch.position_ids[indices] batch.slot_indices = batch.slot_indices[indices[: len(batch)]] - batch.adapter_meta.adapter_indices = ( - batch.adapter_meta.adapter_indices[indices] - ) + if batch.adapter_meta is not None: + batch.adapter_meta.adapter_indices = ( + batch.adapter_meta.adapter_indices[indices] + ) # For each member of the batch # Cumulative length - accepted_ids = accepted_ids.cpu() - cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) - torch.cumsum(accepted_ids, dim=0, 
out=cu_accepted_ids[1:]) - next_input_ids = next_input_ids.cpu() + if batch.speculative_logits is not None: + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) for i in range(len(batch)): batch.all_input_ids_tensor[ i, @@ -1912,33 +2141,47 @@ class FlashCausalLM(Model): + batch.input_lengths[i] + accepted_ids[i], ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + accepted_ids = accepted_ids.cpu() + if batch.position_ids.dim() == 2: + # Qwen2_vl case: + batch.position_ids += accepted_ids.unsqueeze(-1) + else: + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += ( + batch.input_lengths_tensor + accepted_ids - 1 + ) + batch.input_lengths_tensor = torch.ones_like( + batch.input_lengths_tensor + ) + batch.slot_indices += accepted_ids[: len(batch)] else: index = batch.cache_lengths_tensor + batch.input_lengths_tensor - index = index.to(batch.all_input_ids_tensor) + index = F.pad( + index, (0, next_input_ids.shape[0] - index.shape[0]), value=0 + ) + index = index.to(batch.all_input_ids_tensor.device) batch_idx = torch.arange( 0, - batch.all_input_ids_tensor.shape[0], + index.shape[0], dtype=torch.long, device=batch.all_input_ids_tensor.device, ) batch.all_input_ids_tensor.index_put_( (batch_idx, index.long()), next_input_ids ) - batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] + batch.input_ids = next_input_ids + batch.position_ids += 1 + batch.cache_lengths_tensor += batch.input_lengths_tensor + batch.input_lengths_tensor = torch.ones_like( + batch.input_lengths_tensor + ) + batch.slot_indices += 1 + batch.speculative_ids = speculative_ids - if batch.position_ids.dim() == 2: - # Qwen2_vl case: - batch.position_ids += accepted_ids.unsqueeze(-1) - else: - batch.position_ids += accepted_ids - batch.cache_lengths_tensor += ( - batch.input_lengths_tensor + accepted_ids - 1 - ) - batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) - batch.slot_indices += accepted_ids[: len(batch)] # Does a HPU <-> CPU sync internally - if prefill: + if prefill and batch.adapter_meta is not None: # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments( batch.adapter_meta.adapter_indices @@ -2008,7 +2251,18 @@ class FlashCausalLM(Model): htorch.core.mark_step() # Stage 2. 
Prepare new batch for speculative scheduling if len(batches) > 1: - batch = self.batch_type.concatenate(batches) + if self.bucketing_ctx is not None: + total_batch_size = 0 + for b in batches: + total_batch_size += len(b) + padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size( + total_batch_size + ) + batch = self.batch_type.concatenate( + batches, padded_total_bs=padded_total_bs + ) + else: + batch = self.batch_type.concatenate(batches) else: batch = batches[0] prefill = batch.prefilling @@ -2019,40 +2273,48 @@ class FlashCausalLM(Model): batch.max_input_length ), self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)), + self.max_total_tokens, ) else: - batch.prepare_for_prefill(batch.max_input_length, len(batch)) + batch.prepare_for_prefill( + batch.max_input_length, len(batch), self.max_total_tokens + ) else: batch.prepare_for_decode( self.dtype, self.use_contiguous_pa, self.bucketing_ctx ) + if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds): + self.set_inputs_embeds(batch) prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) adapter_meta = batch.adapter_meta - if batch.speculative_ids is not None: - B, speculative_length = batch.speculative_ids.shape - new_length = speculative_length + 1 - adapter_indices = ( - adapter_meta.adapter_indices.unsqueeze(-1) - .expand(B, new_length) - .reshape(-1) - ) - adapter_segments = adapter_meta.adapter_segments * new_length - adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_meta.adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_meta.segment_indices, - ) + if adapter_meta is not None: + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1) + .expand(B, new_length) + .reshape(-1) + ) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) - # Assign pointers to adapter weights - # TODO(travis): don't update this if indices haven't changed - adapter_data = AdapterBatchData.from_meta( - adapter_meta, - self.layer_to_adapter_weights, - prefill, - batch.prefill_head_indices, - ) + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + prefill, + batch.prefill_head_indices, + ) + else: + adapter_data = None out, speculative_logits = self.forward(batch, adapter_data) diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 1776b219..5bd2292e 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -1,7 +1,7 @@ import torch from PIL import Image from io import BytesIO - +from dataclasses import dataclass from opentelemetry import trace from typing import Iterable, Optional, Tuple, List, Type, Dict @@ -23,9 +23,11 @@ from text_generation_server.layers.attention import ( _async_h2d_tensor_copy, ) import habana_frameworks.torch as htorch +import time from 
text_generation_server.utils.import_utils import ( synchronize, ) +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) @@ -37,6 +39,33 @@ IDEFICS3_FAKE_IMAGE_TOKEN = "" IDEFICS3_GLOBAL_IMG_TOKEN = "" +def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk): + """ + Create a structured string representation of image tokens + + Args: + num_patches: Number of patches in the image + + Returns: + String with appropriate image tokens + """ + img_string = "<|image_start|>" + ratio_h, ratio_w = aspect_ratio + if ratio_h * ratio_w > 1: + for yy in range(ratio_h): + for xx in range(ratio_w): + img_string += "<|patch|>" * num_patches_per_chunk + if xx < ratio_w - 1: + img_string += "<|tile_x_separator|>" + + img_string += "<|tile_y_separator|>" + img_string += "<|image|>" + img_string += "<|patch|>" * num_patches_per_chunk + img_string += "<|image_end|>" + + return img_string + + # copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 def _prompt_split_image( *, @@ -90,17 +119,17 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def image_text_replacement(processor, image_input, config, image_id: int) -> str: +def image_text_replacement(processor, image_input, config) -> str: if config.model_type == "idefics2": image_seq_len = 64 image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" if processor.image_processor.do_image_splitting: image_str *= 5 - return image_str + return image_str, IDEFICS2_FAKE_TOKEN if config.model_type == "idefics3": # TODO: implement this in a more general way - n_rows = image_input["rows"][0][image_id] - n_cols = image_input["cols"][0][image_id] + n_rows = image_input["rows"][0][0] + n_cols = image_input["cols"][0][0] image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) @@ -113,35 +142,52 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str image_token=IDEFICS3_IMAGE_TOKEN, global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, ) - return image_str + return image_str, IDEFICS3_FAKE_IMAGE_TOKEN elif config.model_type == "llava_next": - height, width = image_input["image_sizes"][image_id] + height, width = image_input["image_sizes"][0] num_features = get_number_of_features(height, width, config) log_master( logger.info, f"Found {num_features} features in image of resolution {height}x{width}", ) - return "" * num_features + return "" * num_features, "" elif config.model_type == "paligemma": - return "" * config.text_config.num_image_tokens + return "" * config.text_config.num_image_tokens, "" elif config.model_type == "qwen2_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>" elif config.model_type == "qwen2_5_vl": - grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id] + grid_t, grid_h, grid_w = image_input["image_grid_thw"][0] num_pads = grid_t * grid_h * grid_w // 4 padding = "<|image_pad|>" * num_pads - return f"<|vision_start|>{padding}<|vision_end|>" + return f"<|vision_start|>{padding}<|vision_end|>", 
"<|vision_start|>" elif config.model_type == "gemma3": # TODO: get correct number of features via reviewing the Gemma3 architecture # and calculating the number of image tokens num_pads = 256 padding = "" * num_pads - return f"\n\n{padding}\n\n" + return f"\n\n{padding}\n\n", "" + elif config.model_type == "llama4": + patch_size = config.vision_config.patch_size + pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio + downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2))) + aspect_ratios = image_input["aspect_ratios"][0] + image_height, image_width = image_input["pixel_values"][0].shape[-2:] + + num_patches_per_chunk = int( + (image_height // patch_size) + * (image_width // patch_size) + // downsample_ratio + ) + tokens_for_this_image = prompt_split_image_llama4( + aspect_ratios, num_patches_per_chunk + ) + + return tokens_for_this_image, "<|image_start|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -154,6 +200,27 @@ def image_text_replacement_fixup(config, text: str) -> str: return text +def preprocess_text(config, text: str) -> str: + if config.model_type == "paligemma": + return "" + text + "\n" + return text + + +def preprocess_image(config, img): + model_type = config.model_type + + if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20: + img = img.resize((img.width * 2, img.height * 2)) + if model_type == "paligemma": + img = img.convert("RGB") + + if model_type not in {"llava_next", "gemma3", "llama4"}: + # TODO: check if this is needed + img = [img] + + return img + + def get_unpadded_features( original_height: int, original_width: int, @@ -208,103 +275,259 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features +def scatter_image_embeds( + embeds: torch.Tensor, is_embed: Optional[torch.Tensor] +) -> torch.Tensor: + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed.to(embeds.device)] = embeds + return placeholders + + +def gather_image_embeds( + embeds: torch.Tensor, is_embed: Optional[torch.Tensor] +) -> Optional[torch.Tensor]: + if is_embed is None: + return embeds + sel = embeds[is_embed.to(embeds.device)] + return sel if sel.numel() else None + + +@dataclass +class ImagePositions: + offset: int + length: int + id: int + num_placeholder_tokens: int + is_embed: Optional[torch.Tensor] = None + + class FlashVlmCausalLMBatch(FlashCausalLMBatch): + image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]] + image_positions: Optional[List[List[ImagePositions]]] + encoder_cache: Optional[List[Dict[int, torch.Tensor]]] pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] image_grid_thw: Optional[torch.Tensor] + cache_entries_to_free: List[Tuple[int, int]] + has_image_inputs: bool = False + inputs_embeds: Optional[torch.Tensor] = None @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches) + def concatenate(cls, batches, padded_total_bs: int = 0): + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs) + batch.image_inputs = [] + batch.image_positions = [] + batch.encoder_cache = [] + for b in batches: + if b.image_inputs is not None: + batch.image_inputs.extend(b.image_inputs) + else: + batch.image_inputs.append(None) + if 
b.image_positions is not None: + batch.image_positions.extend(b.image_positions) + else: + batch.image_positions.append(None) + if b.encoder_cache is not None: + batch.encoder_cache.extend(b.encoder_cache) + else: + batch.encoder_cache.append(None) + batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None + batch.inputs_embeds = None + # To be filled in prepare_for_prefill + batch.has_image_inputs = False + batch.cache_entries_to_free = [] return batch @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + + image_inputs = [] + image_positions = [] + encoder_cache = [] + + for request_id in request_ids: + idx = self.requests_idx_mapping[request_id] + image_inputs.append(self.image_inputs[idx]) + image_positions.append(self.image_positions[idx]) + encoder_cache.append(self.encoder_cache[idx]) + batch = super().filter(request_ids) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None + batch.inputs_embeds = None + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = encoder_cache + + # To be filled in prepare_for_prefill + batch.has_image_inputs = False + batch.cache_entries_to_free = [] return batch @classmethod def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config ): - # Process images first. We need all of them so that the processor - # can make the image splits the same size. And we need the final - # sizes to insert correct number of image tokens. - images = [] + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + max_length = 0 + vocab = tokenizer.get_vocab() + + if not hasattr(config, "image_token_index"): + config.image_token_index = config.image_token_id + + batch_tokenized_inputs: List[List[int]] = [] + batch_image_inputs: List[Optional[List[dict]]] = [] + batch_image_positions: List[Optional[List[ImagePositions]]] = [] + for r in requests: + text_parts = [] + image_inputs = [] + image_texts = [] + + image_id = 0 + for chunk in r.input_chunks.chunks: chunk_type = chunk.WhichOneof("chunk") if chunk_type == "text": - pass + text = preprocess_text(config, chunk.text) + text_parts.append(text) elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the - # default warmup image is 20x20 - if config.model_type in {"qwen2_vl", "qwen2_5_vl"}: - if image.width <= 20: - w = image.width * 2 - h = image.height * 2 - image = image.resize((w, h)) + img = Image.open(BytesIO(chunk.image.data)) + img = preprocess_image(config, img) - if config.model_type == "llava_next": - images.append(image) - elif config.model_type == "gemma3": - images.append(image) - else: - images.append([image]) + image_input = processor.image_processor( + [img], return_tensors="pt", **kwargs + ) + image_inputs.append(image_input) + + img_text, img_start_token_str = image_text_replacement( + processor, image_input, config + ) + text_parts.append(img_text) + + image_texts.append([image_id, img_start_token_str, img_text]) + image_id += 1 else: raise RuntimeError(f"Invalid chunk type {chunk_type}") - if images: - kwargs = {} - if ( - hasattr(processor, "image_processor_class") - 
and processor.image_processor_class == "Idefics3ImageProcessor" - ): - kwargs["return_row_col_info"] = True - - image_inputs = processor.image_processor( - images, return_tensors="pt", **kwargs - ) - else: - image_inputs = None - - batch_tokenized_inputs = [] - max_length = 0 - image_id = 0 - for r in requests: - full_text = "" - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += chunk.text - elif chunk_type == "image": - full_text += image_text_replacement( - processor, image_inputs, config, image_id - ) - image_id += 1 - - full_text = image_text_replacement_fixup(config, full_text) + full_text = image_text_replacement_fixup(config, "".join(text_parts)) input_ids = tokenizer( full_text, truncation=True, max_length=r.truncate, - add_special_tokens=r.add_special_tokens, + add_special_tokens=( + r.add_special_tokens if config.model_type != "paligemma" else False + ), )["input_ids"] max_length = max(max_length, len(input_ids)) - batch_tokenized_inputs.append(input_ids) - return batch_tokenized_inputs, image_inputs + if len(image_inputs) > 0: + img_start_token = vocab[image_texts[0][1]] + image_positions = cls.get_image_positions( + input_ids, image_texts, img_start_token, config, tokenizer + ) + else: + image_inputs = None + image_positions = None + + batch_tokenized_inputs.append(input_ids) + batch_image_inputs.append(image_inputs) + batch_image_positions.append(image_positions) + + return batch_tokenized_inputs, batch_image_inputs, batch_image_positions + + @classmethod + def get_image_positions( + cls, + input_ids: List[int], + image_texts: List[Tuple[int, str, str]], + img_start_token: int, + config, + tokenizer: PreTrainedTokenizerBase, + ) -> List[ImagePositions]: + image_positions = [] + num_images = len(image_texts) + + input_ids_t = torch.as_tensor(input_ids) + img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0] + num_tokens = input_ids_t.numel() + + last_pos = 0 + for i in range(num_images): + image_id, img_start_token_str, img_text = image_texts[i] + img_text = image_text_replacement_fixup(config, img_text) + + if config.model_type == "gemma3": + img_text = img_text.replace("\n\n", "") + + tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ][0] + length = tokens.numel() + + assert ( + length <= num_tokens + ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens" + + pos = torch.searchsorted(img_start_token_pos, last_pos, right=False) + index = img_start_token_pos[pos] + assert torch.equal( + input_ids_t[index : index + length], tokens + ), "Image tokens not found in input_ids" + + is_embed = tokens == config.image_token_index + num_placeholder_tokens = int(is_embed.sum()) + if num_placeholder_tokens == length: + is_embed = None + + pos = ImagePositions( + offset=index, + length=length, + id=image_id, + num_placeholder_tokens=num_placeholder_tokens, + is_embed=is_embed, + ) + + image_positions.append(pos) + last_pos = index + length + + if ( + config.model_type == "idefics2" + and i + 1 != num_images + and input_ids[last_pos] == config.image_token_index + ): + fake_token = last_pos - 1 + fake_token_index = torch.searchsorted( + img_start_token_pos, fake_token, right=False + ) + img_start_token_pos[fake_token_index] = last_pos + image_texts[i + 1][2] = image_texts[i + 1][2][ + len(img_start_token_str) : + ] + + return image_positions @classmethod def from_pb_processor( @@ -316,33 +539,164 @@ class 
FlashVlmCausalLMBatch(FlashCausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "FlashVlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config + batch_tokenized_inputs, image_inputs, image_positions = ( + cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config) ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "image_grid_thw" in image_inputs: - batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) - else: - batch.image_grid_thw = None - else: + batch.image_inputs = image_inputs + batch.image_positions = image_positions + batch.encoder_cache = [{} for _ in range(len(pb.requests))] + if len(image_inputs): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None batch.image_grid_thw = None return batch + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): + super().prepare_for_prefill( + max_padded_input_len, max_padded_bs, max_total_tokens + ) + + self.has_image_inputs = False + self.cache_entries_to_free = [] + + self.pixel_values = [] + + assert ( + len(self.cache_lengths) + == len(self.input_lengths) + == len(self.prefilling_mask) + ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask" + + for i, ( + cache_length, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.cache_lengths, + self.input_lengths, + self.prefilling_mask, + ) + ): + if not request_prefilling or self.image_positions[i] is None: + continue + + for image_position in self.image_positions[i]: + if image_position is None: + continue + start_pos = image_position.offset + length = image_position.length + + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + self.has_image_inputs = True + + if image_position.id not in self.encoder_cache[i]: + image_inputs = self.image_inputs[i][image_position.id] + self.pixel_values.append((i, image_position.id, image_inputs)) + + # Remove the image from the image_inputs + self.image_inputs[i][image_position.id] = None + + if not self.has_image_inputs: + self.pixel_values = None + self.pixel_attention_mask = None + self.image_sizes = None + self.image_grid_thw = None + else: + image_grid_thw_list = [ + x[2]["image_grid_thw"] + for x in self.pixel_values + if "image_grid_thw" in x[2] + ] + if image_grid_thw_list: + self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0) + else: + self.image_grid_thw = None + + def update_encoder_cache(self, encoder_outputs, request_id, img_pos): + self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds( + encoder_outputs, img_pos.is_embed + ) + + def gather_vision_embeds(self): + device = self.input_ids.device + chunks = [] + for ( + i, + cache_length, + input_length, + request_prefilling, + ) in zip( + range(len(self.requests)), + self.cache_lengths, + self.input_lengths, + self.prefilling_mask, + ): + if not request_prefilling or 
self.image_positions[i] is None: + continue + + for image_position in self.image_positions[i]: + if image_position is None: + continue + start_pos = image_position.offset + length = image_position.length + + if start_pos >= cache_length + input_length: + # No encoder input required at this step + break + if start_pos + length <= cache_length: + # The encode input is already processed + continue + + start_idx = max(cache_length - start_pos, 0) + end_idx = min(cache_length - start_pos + input_length, length) + + assert ( + image_position.id in self.encoder_cache[i] + ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}" + encoder_output = self.encoder_cache[i][image_position.id] + + is_embed = image_position.is_embed + if is_embed is not None: + is_embed = is_embed[start_idx:end_idx] + + from loguru import logger + + logger.info( + f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}" + ) + + embeds = gather_image_embeds( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + if embeds is not None: + chunks.append(embeds) + + if end_idx == length: + self.cache_entries_to_free.append((i, image_position.id)) + self.image_positions[i][image_position.id] = None + + if len(chunks) == 0: + return None + return torch.cat(chunks, dim=0).to(device) + + def free_encoder_cache(self): + for i, image_id in self.cache_entries_to_free: + self.encoder_cache[i].pop(image_id, None) + + self.cache_entries_to_free = [] + class FlashVlmCausalLM(FlashCausalLM): def __init__( @@ -354,6 +708,7 @@ class FlashVlmCausalLM(FlashCausalLM): batch_class=FlashVlmCausalLMBatch, revision, trust_remote_code: bool, + support_chunking: bool = False, **kwargs, ): if PREFIX_CACHING: @@ -371,8 +726,7 @@ class FlashVlmCausalLM(FlashCausalLM): model_id=model_id, revision=revision, trust_remote_code=trust_remote_code, - # FIXME: VLM do not work with context chunking yet - support_chunking=False, + support_chunking=support_chunking, **kwargs, ) @@ -423,9 +777,12 @@ class FlashVlmCausalLM(FlashCausalLM): bucketing_ctx=None, ) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) + inputs_embeds = self.get_inputs_embeds( + input_ids=input_ids.to(self.device), + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + inputs_embeds=inputs_embeds, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=None, kv_cache=self.kv_cache, @@ -433,27 +790,145 @@ class FlashVlmCausalLM(FlashCausalLM): seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=hpu_attention_meta, lm_head_indices=None, - pixel_values=None, - pixel_attention_mask=None, - image_sizes=None, - image_grid_thw=None, + attention_mask=None, ) def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_reserved + graph_free_mem = self.align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) + decode_available_memory = graph_free_mem + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(decode_available_memory)} for decode " + ) + log_master(logger.info, msg) + start_time = time.time() + warmup_shape_count = 0 warmup_times = 3 + # only warmup decode, for prefill, image pixal size may change, make the warmup useless + def ordering_function_max_bs(b): + return (-b[0], b[1]) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) - for i, (batch_size, block_num) in enumerate( - reversed(self.bucketing_ctx.decode_buckets) - ): + buckets = list( + sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) + ) + total_batch_seq = 0.001 + total_mem = 0 + available_mem = decode_available_memory + for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue - log_master( - logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, block_num, False) + if not mem_estimate >= available_mem: + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) + warmup_shape_count += 1 + self.log_warmup(False, i, len(buckets), batch_size, block_num) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX ) - for index in range(warmup_times): - self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + if graphed_bucket in self.graphed_buckets: + + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + log_master( + logger.info, + f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", + ) + + def get_vision_embeds( + self, + pixel_values: torch.Tensor, + pixel_attention_mask: torch.Tensor, + image_sizes: torch.Tensor, + image_grid_thw: torch.Tensor, + ): + embeds = self.model.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + return embeds + + def get_inputs_embeds( + self, + input_ids: torch.Tensor, + vision_embeds: Optional[torch.Tensor] = None, + ): + return self.model.get_inputs_embeds( + input_ids=input_ids, + vision_embeds=vision_embeds, + ) + + def encode_images(self, batch): + if batch.pixel_values is not None: + device = batch.input_ids.device + for request_id, image_id, image_input in batch.pixel_values: 
+ pixel_values = image_input["pixel_values"].to(device) + + if "pixel_attention_mask" in image_input: + pixel_attention_mask = image_input["pixel_attention_mask"].to( + device + ) + else: + pixel_attention_mask = None + + if "image_sizes" in image_input: + image_sizes = image_input["image_sizes"].to(device) + else: + image_sizes = None + + if "image_grid_thw" in image_input: + image_grid_thw = image_input["image_grid_thw"] + else: + image_grid_thw = None + + encoder_outputs = self.get_vision_embeds( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + ) + batch.update_encoder_cache( + encoder_outputs, + request_id, + batch.image_positions[request_id][image_id], + ) + + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + + def set_inputs_embeds(self, batch): + if batch.has_image_inputs: + self.encode_images(batch) + vision_embeds = batch.gather_vision_embeds() + batch.has_image_inputs = False + else: + vision_embeds = None + + inputs_embeds = self.get_inputs_embeds( + batch.input_ids, vision_embeds=vision_embeds + ) + + batch.inputs_embeds = inputs_embeds def forward( self, @@ -502,6 +977,7 @@ class FlashVlmCausalLM(FlashCausalLM): position_ids = new_position_ids else: input_ids = batch.input_ids + inputs_embeds = batch.inputs_embeds position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = self.kv_cache @@ -514,10 +990,25 @@ class FlashVlmCausalLM(FlashCausalLM): if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: if position_ids.dim() == 1 and batch.prefilling: position_ids = self.model.get_position_ids( - input_ids, batch.image_grid_thw + input_ids.cpu(), batch.image_grid_thw ) batch.position_ids = position_ids + attention_mask = None + attention_mask_forward = None + if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None: + attention_mask = self.model.get_attention_mask( + input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True + ) + min_dtype = torch.finfo(self.dtype).min + attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to( + input_ids.device + ) + attention_mask = attention_mask.reshape(-1) + if self.model.config.model_type == "llama4": + attention_mask = (input_ids != 0).long() + attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1) + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. 
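The hunks above move vision handling to a per-request encoder cache: encode_images runs the vision tower once per image, scatter_image_embeds stores the output aligned to the image's placeholder tokens, and gather_vision_embeds later extracts only the slice covered by the current prefill chunk. Below is a minimal standalone sketch of that scatter/gather round trip and the chunk-window arithmetic; it is simplified (device handling omitted) and the hidden size, span length, and offsets are invented for illustration.

import torch

def scatter_image_embeds(embeds: torch.Tensor, is_embed: torch.Tensor) -> torch.Tensor:
    # Store vision-tower output aligned to the image's token span; positions that are
    # not real <image> placeholders (start/end markers, tile separators) become NaN rows.
    if is_embed is None:
        return embeds
    placeholders = embeds.new_full((is_embed.shape[0], embeds.shape[-1]), float("nan"))
    placeholders[is_embed] = embeds
    return placeholders

def gather_image_embeds(embeds: torch.Tensor, is_embed: torch.Tensor):
    # Inverse operation: keep only the rows that correspond to placeholder tokens.
    if is_embed is None:
        return embeds
    sel = embeds[is_embed]
    return sel if sel.numel() else None

# One image whose prompt span is 6 tokens, 4 of which are <image> placeholders.
is_embed = torch.tensor([False, True, True, True, True, False])
vision_out = torch.randn(4, 8)                       # 4 placeholder embeddings, hidden size 8
cached = scatter_image_embeds(vision_out, is_embed)  # (6, 8), NaN at rows 0 and 5

# Chunked prefill: the image span starts at prompt position 8, 10 tokens are already
# cached, and this step processes 3 new tokens, i.e. span slice [2, 5).
cache_length, start_pos, input_length, length = 10, 8, 3, 6
start_idx = max(cache_length - start_pos, 0)                    # 2
end_idx = min(cache_length - start_pos + input_length, length)  # 5
step_embeds = gather_image_embeds(cached[start_idx:end_idx], is_embed[start_idx:end_idx])
print(step_embeds.shape)  # torch.Size([3, 8])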
@@ -526,14 +1017,21 @@ class FlashVlmCausalLM(FlashCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = batch.prefilling - + batch_size = input_lengths.shape[0] + seqlen = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] + ) + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, seqlen, batch_size + ) if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad else: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[: slots.shape[0]] = slots slots = slots_pad @@ -541,7 +1039,7 @@ class FlashVlmCausalLM(FlashCausalLM): input_lengths=_async_h2d_tensor_copy(input_lengths), ) logits, speculative_logits = self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + inputs_embeds=inputs_embeds, position_ids=_async_h2d_tensor_copy(position_ids), cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, @@ -549,18 +1047,9 @@ class FlashVlmCausalLM(FlashCausalLM): seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, lm_head_indices=_async_h2d_tensor_copy(lm_head_indices), - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - image_grid_thw=batch.image_grid_thw, + attention_mask=attention_mask_forward, **kwargs, ) - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - if batch.image_grid_thw is not None: - batch.image_grid_thw = None + batch.image_grid_thw = None + batch.free_encoder_cache() return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/models/galactica.py b/backends/gaudi/server/text_generation_server/models/galactica.py deleted file mode 100644 index 7c4e462c..00000000 --- a/backends/gaudi/server/text_generation_server/models/galactica.py +++ /dev/null @@ -1,156 +0,0 @@ -import re -import torch -import torch.distributed - - -from transformers import ( - PreTrainedTokenizerBase, -) -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, -) -from text_generation_server.utils.chunks import concat_text_chunks - -# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py - -# we split individual characters inside special tokens like [START_DNA] -CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])") - -# token added to implement a custom sequence tokenization. This token is added at -# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance -# that they do not occur in the corpus. The digits are escaped so that the token does not appear -# literally in the source code in case we ever include it in the training data. -SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E" - - -def _insert_split_marker(m: re.Match): - """ - Applies split marker based on a regex match of special tokens such as - [START_DNA]. 
- Parameters - ---------- - n : str - Input text to split - Returns - ---------- - str - the text with the split token added - """ - start_token, _, sequence, end_token = m.groups() - sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL) - return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}" - - -def escape_custom_split_sequence(text): - """ - Applies custom splitting to the text for GALILEO's tokenization - Parameters - ---------- - text : str - Input text to split - Returns - ---------- - str - the text with the split token added - """ - return CUSTOM_SEQ_RE.sub(_insert_split_marker, text) - - -# END CREDIT - - -class GalacticaCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "GalacticaCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - top_n_tokens = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - # Add escape_custom_split_sequence to the CausalLMBatch logic - inputs.append( - escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks)) - ) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - tokenized_inputs = tokenizer( - inputs, - return_tensors="pt", - padding=True, - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(0) - read_offsets.append(input_len) - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - max_tokens = len(inputs) * max_input_length + max_decode_tokens - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) 
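For reference, the removed Galactica helper above interleaves a split marker between every character inside [START_*]...[END_*] spans so the tokenizer splits them one character at a time. A small self-contained example of that behavior; the regex and marker are copied from the deleted module, while the sample DNA string is made up.

import re

CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

def _insert_split_marker(m: re.Match) -> str:
    # Prefix every character of the inner sequence with the split marker.
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"

def escape_custom_split_sequence(text: str) -> str:
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)

print(escape_custom_split_sequence("[START_DNA]ACGT[END_DNA]"))
# [START_DNA]SPL1T-TH1S-Pl3A5EASPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EGSPL1T-TH1S-Pl3A5ETSPL1T-TH1S-Pl3A5E[END_DNA]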
diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py index cd221e14..cdde67ca 100644 --- a/backends/gaudi/server/text_generation_server/models/globals.py +++ b/backends/gaudi/server/text_generation_server/models/globals.py @@ -4,14 +4,14 @@ from loguru import logger from text_generation_server.utils.log import log_master REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} -ATTENTION = os.getenv("ATTENTION", "default") +ATTENTION = os.getenv("ATTENTION", "paged") # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in { "1", "true", } log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "default"} +_expected = {"paged"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py deleted file mode 100644 index 98d7352a..00000000 --- a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py +++ /dev/null @@ -1,882 +0,0 @@ -from io import BytesIO -from PIL import Image -import torch -import time - -from dataclasses import dataclass -from opentelemetry import trace -from transformers import ( - AutoConfig, - AutoProcessor, - AutoTokenizer, - PreTrainedTokenizerBase, - ProcessorMixin, -) -from typing import Optional, Tuple, List, Type, Dict - -from text_generation_server.models import Model -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -import torch.distributed -from text_generation_server.models.custom_modeling.idefics_modeling import ( - IdeficsForVisionText2Text, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.quantization import get_loader - -tracer = trace.get_tracer(__name__) - - -@dataclass -class IdeficsCausalLMBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - pixel_values: Optional[torch.Tensor] - image_hidden_states: Optional[torch.Tensor] - image_attention_mask: Optional[torch.Tensor] - past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> 
"IdeficsCausalLMBatch": - raise NotImplementedError - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor: ProcessorMixin, # Hack - config, - dtype: torch.dtype, - device: torch.device, - ) -> "IdeficsCausalLMBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(r.input_chunks.chunks) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - # TODO Check impact on idefics - prompts = [] - for inp in inputs: - # Each input is encoded into a list, where each element of this input list is either a string or a URL - prompt = [] - for chunk in inp: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - prompt.append(chunk.text) - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - prompt.append(image) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - prompts.append(prompt) - - # The processor replaces the call to tokenizer, and - # a/ takes care of fetching images from the URL - # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model - tokenized_inputs = processor( - prompts, - return_tensors="pt", - padding=True, - truncation=True, - max_length=max_truncation, - # TODO Check impact on idefics - # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append( - input_len - 5 - ) # To decode without potential fallbacks errors - read_offsets.append( - input_len - ) # To decode without potential fallbacks errors - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - - input_ids = tokenized_inputs["input_ids"] - pixel_values = tokenized_inputs.get("pixel_values", None) - image_hidden_states = None - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) - ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] - # Do the same for image_attention_mask - if pixel_values is None: - image_attention_mask = None - else: - image_attention_mask = input_ids.new_zeros( - ( - pb.size, - max_input_length + padding_right_offset, - pixel_values.size(1), - ) - ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ - "image_attention_mask" - ] - - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split( - 1, dim=1 - ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. 
It is then transformed into a list - - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: - # It deletes requests from the batch. For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - # Do the same for pixel_values and image_attention_mask - pixel_values = self.pixel_values[keep_indices] - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] - if self.image_hidden_states is None: - image_hidden_states = None - else: - image_hidden_states = self.image_hidden_states[keep_indices] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) is tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - 
for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.pixel_values = pixel_values - self.image_hidden_states = image_hidden_states - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - return self - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, batches: List["IdeficsCausalLMBatch"] - ) -> "IdeficsCausalLMBatch": - # It adds new requests to the batch - # Used for padding - total_batch_size = 0 - max_input_length = 0 - max_num_images = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - max_num_images = max(max_num_images, batch.pixel_values.size(1)) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - pixel_values = None - image_hidden_states = None - image_attention_mask = None - past_key_values = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - 
attention_mask = batch.attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset), - ) - - curr_batch_max_num_images = batch.pixel_values.size(1) - if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) - ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values - ) - - if image_attention_mask is None: - image_attention_mask = batch.image_attention_mask.new_zeros( - ( - total_batch_size, - max_input_length + padding_right_offset, - max_num_images, - ) - ) - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values[0], tuple): - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] - ) - else: - # BLOOM case - 
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] - ) - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values - - # Update values - start_index = end_index - - past_key_values.append([padded_past_keys, padded_past_values]) - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) - - def __len__(self): - return len(self.requests) - - -class IdeficsCausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - self.process_group, rank, world_size = initialize_torch_distributed() - device = torch.device("hpu") - dtype = torch.bfloat16 if dtype is None else dtype - self.device, self.dtype = device, dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - config.vision_config.quantize = quantize - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) - - model = IdeficsForVisionText2Text(config, weights) - - self.config = config - - torch.distributed.barrier(group=self.process_group) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[IdeficsCausalLMBatch]: - return IdeficsCausalLMBatch - - def forward( - self, - input_ids, - attention_mask, - position_ids, - pixel_values, - 
image_hidden_states, - image_attention_mask, - past_key_values: Optional = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "image_hidden_states": image_hidden_states, - "image_attention_mask": image_attention_mask, - "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, - } - if self.has_position_ids: - kwargs["position_ids"] = position_ids - - outputs, speculative_logits = self.model.forward(**kwargs) - return ( - outputs.logits, - speculative_logits, - outputs.past_key_values, - outputs.image_hidden_states, - ) - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batch: IdeficsCausalLMBatch - ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]: - start = time.time_ns() - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - if batch.image_attention_mask is None: - image_attention_mask = None - else: - if batch.input_ids.size(1) == 1: - # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images), - # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension - # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated - # token need to attend to the encoder hidden states (i.e. the vision encoder) - # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[ - :, -(batch.padding_right_offset + 1) - ].unsqueeze(1) - else: - image_attention_mask = batch.image_attention_mask[ - :, : -batch.padding_right_offset - ] - - logits, speculative_logits, past, image_hidden_states = self.forward( - input_ids=batch.input_ids, - attention_mask=attention_mask, - position_ids=batch.position_ids, - pixel_values=batch.pixel_values, - image_hidden_states=batch.image_hidden_states, - image_attention_mask=image_attention_mask, - past_key_values=batch.past_key_values, - ) - # Hardcoded remove image tokens - logits[:, 32000:32001] = torch.finfo(logits.dtype).min - - start_decode = time.time_ns() - - # Results - generations: List[Generation] = [] - stopped = True - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be 
appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( - next_token_id_squeezed.item() - ) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) - # Decrease right offset - batch.padding_right_offset -= 1 - - # Update position_ids - batch.position_ids = batch.position_ids[:, -1:] + 1 - - # Update past key values - batch.past_key_values = past - batch.image_hidden_states = image_hidden_states - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/mamba.py b/backends/gaudi/server/text_generation_server/models/mamba.py deleted file mode 100644 index f6dcde68..00000000 --- a/backends/gaudi/server/text_generation_server/models/mamba.py +++ /dev/null @@ -1,814 +0,0 @@ -import torch -import torch.distributed -from transformers import AutoTokenizer, PreTrainedTokenizerBase -from typing import Optional -from text_generation_server.models.custom_modeling.mamba_modeling import ( - MambaConfig, -) -from loguru import logger -from text_generation_server.pb import generate_pb2 -from 
text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL -import time -from text_generation_server.models.custom_modeling.mamba_modeling import ( - MambaModel, - InferenceParams, -) -from text_generation_server.models import Model -from typing import Any, List, Tuple, Type, Dict -from text_generation_server.models.types import ( - Batch, - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.quantization import get_loader -from text_generation_server.utils.tokens import batch_top_tokens, Sampling -from dataclasses import dataclass -from text_generation_server.utils import NextTokenChooser, StoppingCriteria - - -def new_inference_params( - n_blocks: int, - batch_size: int, - d_inner: int, - d_conv: int, - d_state: int, - seqlen_offset: int, - dtype: torch.dtype, - device: torch.device, -): - max_seqlen = 0 - conv_states = torch.zeros( - ( - n_blocks, - batch_size, - d_inner, - d_conv, - ), - device=device, - dtype=dtype, - ) - ssm_states = torch.zeros( - ( - n_blocks, - batch_size, - d_inner, - d_state, - ), - device=device, - dtype=dtype, - ) - inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_offset, - conv_states=conv_states, - ssm_states=ssm_states, - ) - return inference_params - - -@dataclass -class MambaBatch(Batch): - batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] - - # Decoder values - input_ids: torch.Tensor - - # All tokens - all_input_ids: List[torch.Tensor] - - # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] - - # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - top_n_tokens: List[int] - top_n_tokens_tensor: torch.Tensor - - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int - - # Past metadata - keys_head_dim_last: bool = True - - # Inference params - inference_params: Optional[Dict[str, Any]] = None - - def to_pb(self) -> generate_pb2.CachedBatch: - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[r.id for r in self.requests], - size=len(self), - max_tokens=self.max_tokens, - ) - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "MambaBatch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} - - # Parse batch - max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 - for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i - inputs.append(concat_text_chunks(r.input_chunks.chunks)) - next_token_choosers.append( - NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) - - tokenized_inputs = tokenizer( - 
inputs, - return_tensors="pt", - padding=True, - return_token_type_ids=False, - truncation=True, - max_length=max_truncation, - ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) - - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() - input_ids = tokenized_inputs["input_ids"] - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - # past_input_ids=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, - ) - - def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - indices = [] - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - indices.append(idx) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - - top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.top_n_tokens = top_n_tokens - self.top_n_tokens_tensor = top_n_tokens_tensor - self.max_input_length = 
max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - - # TODO - # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. - self.inference_params.conv_states = self.inference_params.conv_states[ - :, indices - ] - self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] - return self - - @classmethod - def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - max_tokens = 0 - seqlen_offset = 0 - - (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape - (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape - dtype = batches[0].inference_params.conv_states.dtype - device = batches[0].inference_params.conv_states.device - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=total_batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=device, - dtype=dtype, - ) - - # Batch tensors - input_ids = None - top_n_tokens_tensor = None - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - top_n_tokens.extend(batch.top_n_tokens) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - total_batch_size, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - inference_params.max_seqlen = max( - inference_params.max_seqlen, batch.inference_params.max_seqlen - ) - assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset" - inference_params.seqlen_offset = max( - inference_params.seqlen_offset, batch.inference_params.seqlen_offset - ) - - inference_params.conv_states[:, start_index:end_index] = ( - batch.inference_params.conv_states - ) - inference_params.ssm_states[:, start_index:end_index] = ( - 
batch.inference_params.ssm_states - ) - - start_index = end_index - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - inference_params=inference_params, - ) - - def __len__(self): - return len(self.requests) - - -class Mamba(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - self.process_group, _rank, world_size = initialize_torch_distributed() - if world_size > 1: - raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") - self.cuda_graphs = {} - if torch.cuda.is_available(): - device = torch.device("cuda") - # Bf16 is important. In f16 accumulations in the matmul are causing - # differences while the server is under load. - # This is detectable by the integration load test - dtype = torch.bfloat16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - "EleutherAI/gpt-neox-20b", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = MambaConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) - model = MambaModel(config, weights) - torch.distributed.barrier(group=self.process_group) - super(Mamba, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - @property - def batch_type(self) -> Type[MambaBatch]: - return MambaBatch - - def warmup(self, batch) -> Optional[int]: - # TODO: implement warmup for Mamba if needed - if CUDA_GRAPHS: - if self.speculate is None or self.speculate == 0: - try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") - # Warmup cuda graphs - for bs in CUDA_GRAPHS: - self.cuda_graph_warmup(bs) - except Exception: - logger.exception("Decode cuda graph warmup failed") - else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") - - return None - - def cuda_graph_warmup(self, batch_size: int): - input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) - n_blocks = len(self.model.blocks) - - d_state = self.model.config.d_state - d_conv = 
self.model.config.d_conv - # Inner takes the expand multiplication - d_inner = self.model.config.d_inner - - # Important seqlen_offset to go through the update mecanism with the state - seqlen_offset = 1 - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - - graph = torch.cuda.CUDAGraph() - - torch.cuda.synchronize() - # Run once outside to warmup - self.model.forward(input_ids=input_ids, inference_params=inference_params) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - logits, speculative_logits = self.model.forward( - input_ids=input_ids, inference_params=inference_params - ) - torch.cuda.synchronize() - graph_dict = { - "input_ids": input_ids, - "inference_params": inference_params, - "graph": graph, - "logits": logits, - "speculative_logits": speculative_logits, - } - self.cuda_graphs[batch_size] = graph_dict - - def tunableop_warmup(self, batch_size: int, seqlen: int): - input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) - n_blocks = len(self.model.blocks) - - d_state = self.model.config.d_state - d_conv = self.model.config.d_conv - # Inner takes the expand multiplication - d_inner = self.model.config.d_inner - - # Important seqlen_offset to go through the update mecanism with the state - seqlen_offset = 1 - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=seqlen, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - - self.model.forward(input_ids=input_ids, inference_params=inference_params) - - def forward( - self, input_ids: torch.Tensor, inference_params: Any - ) -> Tuple[torch.Tensor, torch.Tensor]: - bs = input_ids.shape[0] - padded_bs = bs - if bs == 3: - padded_bs = 4 - elif 3 < bs <= 8: - padded_bs = 8 - elif bs > 8: - padded_bs = (bs + 7) // 8 * 8 - - # Try to find an associated cuda graph - cuda_graph = self.cuda_graphs.get(padded_bs, None) - is_prefill = inference_params is None or inference_params.seqlen_offset == 0 - - if is_prefill or cuda_graph is None: - return self.model( - input_ids, - inference_params=inference_params, - ) - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][:bs] = input_ids - cuda_graph["inference_params"].conv_states[ - :, :bs - ] = inference_params.conv_states - cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states - - # Replay the graph - cuda_graph["graph"].replay() - - inference_params.conv_states.copy_( - cuda_graph["inference_params"].conv_states[:, :bs] - ) - inference_params.ssm_states.copy_( - cuda_graph["inference_params"].ssm_states[:, :bs] - ) - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None - ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits - - def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: - start = time.time_ns() - input_ids = ( - batch.input_ids - ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids - - batch_size, max_seqlen = input_ids.shape - # Inference params - - if batch.inference_params is None: - # 0 is important here - seqlen_offset = 0 - n_blocks = len(self.model.blocks) - d_state = 
self.model.config.d_state - d_conv = self.model.config.d_conv - d_inner = self.model.config.d_inner - inference_params = new_inference_params( - n_blocks=n_blocks, - batch_size=batch_size, - d_state=d_state, - d_conv=d_conv, - d_inner=d_inner, - seqlen_offset=seqlen_offset, - device=self.device, - dtype=self.dtype, - ) - batch.inference_params = inference_params - - # Forward pass - logits, speculative_logits = self.forward( - input_ids, inference_params=batch.inference_params - ) - - # batch.inference_params = new_inference_params - # Results - generations: List[Generation] = [] - stopped = True - - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - torch.log_softmax(logits[:, -1], -1), - accepted_ids, - ) - - start_decode = time.time_ns() - - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - batch.top_n_tokens, - batch_top_token_ids, - batch_top_token_logprobs, - ) - - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - top_n_tokens, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] - ) - - # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) - new_input_length = input_length + 1 - - # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id_squeezed, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed - ) - else: - generated_text = None - - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - 
clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id_squeezed], - [next_token_logprob], - [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[ - i - ].advance_grammar(next_token_id_squeezed.item()) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) - - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 5de9bca8..1be36d09 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -32,6 +32,9 @@ from text_generation_server.utils.import_utils import ( ) import torch.nn.functional as F from text_generation_server.utils.log import log_master +import time +import os +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes tracer = trace.get_tracer(__name__) @@ -43,10 +46,17 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): aspect_ratio_mask: Optional[torch.Tensor] = None cross_attention_states: Optional[torch.Tensor] = None + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): + super(FlashVlmCausalLMBatch, self).prepare_for_prefill( + max_padded_input_len, max_padded_bs, max_total_tokens + ) + @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super().concatenate(batches) + def concatenate(cls, batches, padded_total_bs: int = 0): + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs) batch.pixel_values = None batch.pixel_attention_mask = None @@ -70,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]): assert self.image_indices is not None - batch = super().filter(request_ids) + batch = super(FlashVlmCausalLMBatch, self).filter(request_ids) assert self.image_indices is not None indices = [] for i, request_id in enumerate(request_ids): @@ -96,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): ] else: batch.cross_attention_states = None + batch.pixel_values = None return batch @classmethod @@ -187,7 +198,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): input_ids = np.concatenate(batch.input_ids, dtype=np.int64) else: input_ids = batch.input_ids[0] - batch.input_ids = 
torch.tensor(input_ids, dtype=torch.int64) + batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) @@ -225,6 +236,10 @@ def generate_cross_attention_states( class FlashMllamaCausalLM(FlashVlmCausalLM): + def set_inputs_embeds(self, batch): + # Set the input embeddings to None, as we are using the input_ids for the model + batch.inputs_embeds = None + def warmup_decode( self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch ): @@ -267,6 +282,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): cross_attention_states, image_indices, input_lengths, 1, False ) slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype) + kwargs = {} + if htorch.utils.internal.is_lazy(): + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + False, hpu_attention_meta.block_list.shape[0], batch_size + ) self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), position_ids=_async_h2d_tensor_copy(position_ids), @@ -280,6 +300,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): cross_attention_states=cross_attention_states, indices=_async_h2d_tensor_copy(indices), cross_attention_len=_async_h2d_tensor_copy(cross_attention_len), + **kwargs, ) def warmup_prefill( @@ -325,7 +346,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + True, prompt_len, batch_size + ) self.model.forward( input_ids=_async_h2d_tensor_copy(input_ids), position_ids=_async_h2d_tensor_copy(position_ids), @@ -343,26 +366,103 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): ) def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch): + prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3")) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_reserved + graph_free_mem = self.align_workers( + graph_free_mem, torch.distributed.ReduceOp.MIN + ) + prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem + decode_available_memory = graph_free_mem - prompt_available_memory + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})" + ) + log_master(logger.info, msg) + start_time = time.time() + warmup_shape_count = 0 warmup_times = 3 self.bucketing_ctx.generate_prompt_buckets() - for i, (batch_size, seq_len) in enumerate( - reversed(self.bucketing_ctx.prompt_buckets) - ): - log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}") - for index in range(warmup_times): - self.warmup_prefill(seq_len, batch_size, batch) + + def ordering_function_min_tokens(b): + return (b[0] * b[1], b[1], b[0]) + + buckets = list( + sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens) + ) + graph_free_mem + total_batch_seq = 0.001 + total_mem = 0 + available_mem = prompt_available_memory + for i, (batch_size, seq_len) in enumerate(buckets): + if batch_size * seq_len > self.max_batch_prefill_tokens: + continue + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size * seq_len + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, seq_len, True) + if not ( + mem_estimate >= 
available_mem or batch_seq > self.max_seq_len_to_capture + ): + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) + warmup_shape_count += 1 + self.log_warmup(True, i, len(buckets), batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_prefill(seq_len, batch_size, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX + ) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + def ordering_function_max_bs(b): + return (-b[0], b[1]) + self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) - for i, (batch_size, block_num) in enumerate( - reversed(self.bucketing_ctx.decode_buckets) - ): + buckets = list( + sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs) + ) + free_mem = HabanaMemoryProfiler.current_free_device_memory() + total_batch_seq = 0.001 + total_mem = 0 + available_mem = free_mem - self.mem_reserved + for i, (batch_size, block_num) in enumerate(buckets): if batch_size > block_num: continue - log_master( - logger.info, f"warmup decode bs {batch_size} block_num {block_num}" + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + graphed_bucket = (batch_size, block_num, False) + if not mem_estimate >= available_mem: + if graphed_bucket not in self.graphed_buckets: + self.graphed_buckets.add(graphed_bucket) + warmup_shape_count += 1 + self.log_warmup(False, i, len(buckets), batch_size, block_num) + with HabanaMemoryProfiler() as mem_prof: + for index in range(warmup_times): + self.warmup_decode(batch_size, block_num, batch) + synchronize(self.device) + used_mem = self.align_workers( + mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX ) - for index in range(warmup_times): - self.warmup_decode(batch_size, block_num, batch) - synchronize(self.device) + if graphed_bucket in self.graphed_buckets: + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + log_master( + logger.info, + f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}", + ) def forward( self, @@ -438,15 +538,22 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): kwargs = {} if htorch.utils.internal.is_lazy(): - kwargs["bypass_hpu_graphs"] = ( - batch.prefilling if self.limit_hpu_graphs else False + batch_size = input_lengths.shape[0] + seqlen = ( + input_ids.shape[0] // batch_size + if batch.prefilling + else batch.hpu_attn_meta.block_list.shape[0] ) + kwargs["bypass_hpu_graphs"] = not self.use_graphs( + batch.prefilling, seqlen, batch_size + ) + if batch.prefill_cache_indices is not None: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[batch.prefill_cache_indices] = slots slots = slots_pad else: - slots_pad = torch.zeros_like(input_ids) + slots_pad = torch.zeros_like(input_ids, device=slots.device) slots_pad[: slots.shape[0]] = slots slots = slots_pad orig_bs = len(batch) @@ -475,7 +582,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): input_lengths=_async_h2d_tensor_copy(input_lengths), ) logits, speculative_logits = self.model.forward( - input_ids=_async_h2d_tensor_copy(input_ids), + input_ids=input_ids, position_ids=_async_h2d_tensor_copy(position_ids), 
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill), kv_cache=kv_cache, diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py deleted file mode 100644 index e91aaed9..00000000 --- a/backends/gaudi/server/text_generation_server/models/pali_gemma.py +++ /dev/null @@ -1,71 +0,0 @@ -from io import BytesIO -from PIL import Image -import torch -import torch.distributed -from opentelemetry import trace -from typing import Iterable -from text_generation_server.models.flash_vlm_causal_lm import ( - FlashVlmCausalLMBatch, - image_text_replacement, -) - -from text_generation_server.pb.generate_pb2 import Request - -tracer = trace.get_tracer(__name__) - - -class PaliGemmaBatch(FlashVlmCausalLMBatch): - @classmethod - def batch_tokenized_inputs( - cls, requests: Iterable[Request], tokenizer, processor, config - ): - batch_inputs = [] - image_inputs = [] - max_truncation = 0 - for r in requests: - full_text = "" - image_id = 0 - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - full_text += "" + chunk.text + "\n" - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # TODO do_convert_RGB should be on by default ? - image = image.convert("RGB") - image_input = processor.image_processor(image, return_tensors="pt") - full_text += image_text_replacement( - processor, image_input, config, image_id - ) - image_inputs.append(image_input) - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=False, - )["input_ids"] - if image_inputs: - image_input = image_inputs[0] - new_image_inputs = { - "pixel_values": torch.cat( - [img["pixel_values"] for img in image_inputs], dim=0 - ), - } - if "pixel_attention_mask" in image_input: - new_image_inputs["pixel_attention_mask"] = torch.cat( - [img["pixel_attention_mask"] for img in image_inputs], dim=0 - ) - if "image_sizes" in image_input: - new_image_inputs["image_sizes"] = torch.cat( - [img["image_sizes"] for img in image_inputs], dim=0 - ) - image_inputs = new_image_inputs - else: - image_inputs = None - return batch_tokenized_inputs, image_inputs diff --git a/backends/gaudi/server/text_generation_server/models/starcoder.py b/backends/gaudi/server/text_generation_server/models/starcoder.py deleted file mode 100644 index 6c6ca2cf..00000000 --- a/backends/gaudi/server/text_generation_server/models/starcoder.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from dataclasses import dataclass -from typing import List, Optional, Type - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch - - -@dataclass -class StarCoderCausalLMBatch(CausalLMBatch): - past_key_values: Optional[List[torch.Tensor]] - - def detach_kv_cache(self): - past_keys = [] - past_values = [] - last_dim = int(self.past_key_values[0].size(dim=-1) / 2) - for key_value in self.past_key_values: - past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) - past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) - del self.past_key_values - - return past_keys, past_values - - def attach_kv_cache(self, past_keys, past_values): - self.past_key_values = [ - torch.cat((key, value), dim=-1) - for key, value in zip(past_keys, 
past_values) - ] - - -class StarCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ): - - super(StarCoder, self).__init__( - model_id=model_id, - revision=revision, - dtype=dtype, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return StarCoderCausalLMBatch diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py deleted file mode 100644 index 0e37609e..00000000 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ /dev/null @@ -1,1607 +0,0 @@ -import json -import re -import torch -import os -import time -import math -from PIL import Image -from io import BytesIO -from opentelemetry import trace -from loguru import logger -from typing import Iterable, Optional, Tuple, List, Type, Dict -import tempfile -import copy -from text_generation_server.models import Model -from transformers import PreTrainedTokenizerBase -from text_generation_server.utils import weight_files -from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import ( - CausalLMBatch, - CausalLMRequest, - remove_kv_cache_from_output, -) - -from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, -) - -from transformers import AutoProcessor -import text_generation_server.habana_quantization_env as hq_env -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi -from text_generation_server.utils import ( - HeterogeneousNextTokenChooser, - make_tokenizer_optional, - is_tokenizer_transparent, - pad_next_token_chooser_parameters, -) -import habana_frameworks.torch as htorch -from optimum.habana.utils import HabanaProfile -from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES -from optimum.habana.utils import get_hpu_memory_stats -from optimum.habana.checkpoint_utils import get_ds_injection_policy - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from optimum.habana.checkpoint_utils import model_on_meta - -from text_generation_server.utils.speculate import get_speculate -from text_generation_server.models.types import ( - Tokens, - Generation, - GeneratedText, -) -from text_generation_server.utils.debug import dbg_trace - -tracer = trace.get_tracer(__name__) - -IDEFICS2_FAKE_TOKEN = "" -IDEFICS2_IMAGE_TOKEN = "" - - -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") -BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048)) -MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192)) -PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) -CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] -LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) - - -PREFILL_WARMUP_BATCH_SIZE_LIST = [] -PREFILL_WARMUP_SEQLEN_LIST = [] -DECODE_WARMUP_BATCH_SIZE_LIST = [] -CROSS_ATTENTION_LAYERS = [] - - -def round_up(warmup_list: list, num): - i = 0 - for i in warmup_list: - if num <= i: - break - return i if i > 0 else num - - -def split(string) -> List[Dict[str, str]]: - parts = [] - cursor = 0 - for pattern in IMAGES.finditer(string): - start = pattern.start() - if start != cursor: - parts.append({"type": "text", "content": string[cursor:start]}) - - parts.append({"type": "image", "content": pattern.group(1)}) - cursor = pattern.end() - 
- if cursor != len(string): - parts.append({"type": "text", "content": string[cursor:]}) - - return parts - - -def image_text_replacement(config) -> str: - if config.model_type == "idefics2": - image_seq_len = 64 - image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}" - return image_str - elif config.model_type == "llava_next": - return "" - elif config.model_type == "paligemma": - return "" - elif config.model_type == "mllama": - return "<|image|>" - else: - raise RuntimeError(f"Unknown config {config.model_type} for multimodal") - - -def image_text_replacement_fixup(config, text: str) -> str: - if config.model_type == "idefics2": - return text.replace( - f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN - ) - return text - - -def get_unpadded_features( - original_height: int, - original_width: int, - npatches: int, - num_patch_height: int, - num_patch_width: int, -) -> Tuple[int, int]: - current_height = npatches * num_patch_height - current_width = npatches * num_patch_width - - aspect_ratio: float = original_width / original_height - current_aspect_ratio: float = current_width / current_height - - if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width - padding = (current_height - new_height) // 2 - current_height = current_height - (2 * padding) - else: - new_width = (original_width * current_height) // original_height - padding = (current_width - new_width) // 2 - current_width = current_width - (2 * padding) - - unpadded_features = current_height * current_width - newline_features = current_height - return (unpadded_features, newline_features) - - -def get_number_of_features(height: int, width: int, config) -> int: - # From config - # Hardcoded for CLIP for now - # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] - image_grid_pinpoints = config.image_grid_pinpoints - image_size = config.vision_config.image_size - patch_size = config.vision_config.patch_size - - assert image_size % patch_size == 0 - - npatches = image_size // patch_size - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - [height, width], - image_grid_pinpoints, - image_size, - ) - - unpadded_features, newline_features = get_unpadded_features( - height, width, npatches, num_patch_height, num_patch_width - ) - # The base patch covers the entire image - base_features = npatches**2 - return unpadded_features + newline_features + base_features - - -class VlmCausalLMBatch(CausalLMBatch): - pixel_values: Optional[List[torch.Tensor]] - pixel_attention_mask: Optional[List[torch.Tensor]] - image_sizes: Optional[List[Tuple[int, int]]] - aspect_ratio_ids: Optional[torch.Tensor] = None - aspect_ratio_mask: Optional[torch.Tensor] = None - cross_attention_mask: Optional[torch.Tensor] = None - prefilling: bool = True - token_idx: torch.Tensor = None - - def __init__( - self, - batch_id, - requests, - input_ids, - attention_mask, - position_ids, - past_key_values, - merged_kv_cache, - next_token_chooser, - top_n_tokens, - top_n_tokens_tensor, - input_length, - pixel_values: Optional[List[torch.Tensor]] = None, - pixel_attention_mask: Optional[List[torch.Tensor]] = None, - image_sizes: Optional[List[Tuple[int, int]]] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - 
cross_attention_mask: Optional[torch.Tensor] = None, - prefilling: Optional[bool] = True, - ): - super().__init__( - batch_id=batch_id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=merged_kv_cache, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - ) - - self.pixel_values = pixel_values - self.pixel_attention_mask = pixel_attention_mask - self.image_sizes = image_sizes - self.aspect_ratio_ids = aspect_ratio_ids - self.aspect_ratio_mask = aspect_ratio_mask - self.cross_attention_mask = cross_attention_mask - self.prefilling = prefilling - - @property - def token_idx(self): # noqa: F811 - if self.prefilling: - # no right padding for prefill - token_idx_scalar = self.attention_mask.shape[-1] - 1 - return torch.tensor(token_idx_scalar).to(self.attention_mask.device) - else: - token_idx_scalar = self.attention_mask.shape[-1] - self.right_padding - return torch.tensor(token_idx_scalar).to(self.attention_mask.device) - - def padding_process(self, pad_id: int): - # self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1) - right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1] - self.input_ids = torch.nn.functional.pad( - self.input_ids, (0, right_padding), value=pad_id - ) - self.attention_mask = torch.nn.functional.pad( - self.attention_mask, (0, right_padding), value=0 - ) - # if self.position_ids is not None: - # self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1 - if self.cross_attention_mask is not None: - self.cross_attention_mask = torch.nn.functional.pad( - self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0 - ) - if self.past is not None: - past_key_values_list = list(self.past_key_values) - for layer_id in range(len(self.past)): - past_key_value_list = list(self.past_key_values[layer_id]) - if layer_id not in CROSS_ATTENTION_LAYERS: - past_key_value_list[0] = torch.nn.functional.pad( - self.past_key_values[layer_id][0], - (0, 0, 0, right_padding), - value=0, - ) - past_key_value_list[1] = torch.nn.functional.pad( - self.past_key_values[layer_id][1], - (0, 0, 0, right_padding), - value=0, - ) - past_key_values_list[layer_id] = tuple(past_key_value_list) - self.past_key_values = tuple(past_key_values_list) - - self.prefilling = False - self.input_length = self.input_length - - @classmethod - def from_tokenized( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - batch_tokenized_inputs, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - - dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}") - requests = [ - CausalLMRequest.from_pb(idx, req, tokenizer) - for idx, req in enumerate(pb.requests) - ] - - max_input_length = max(r.data.truncate for r in requests) - max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) - # TODO: Add support for sparse batches - top_n_tokens = [r.top_n_tokens for r in pb.requests] - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - # TODO: by tokenizing all inputs at once we loose information on actual input lengths - # this means that we cannot shift inputs to the left after a long input sequence - # was filtered out - new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - parameters = [r.parameters for r in pb.requests] - # append the dummy 
parameters for dummy request - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - pb=parameters, - dtype=dtype, - device=device, - tokenizer=tokenizer, - quantization_enabled=hq_env.is_quantization_enabled, - ) - tokenized_inputs = batch_tokenized_inputs - input_len = tokenized_inputs["input_ids"].shape[1] - - bucket_size = max_input_length - left_padding = max_input_length - input_len - if is_warmup is False: - rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) - bucket_size = rounded_seq_len - 1 - left_padding = bucket_size - input_len - - input_ids = tokenized_inputs["input_ids"] - attention_mask = tokenized_inputs["attention_mask"] - cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None) - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - if cross_attention_mask is not None: - cross_attention_mask = torch.nn.functional.pad( - cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - - # New input length after left padding - input_len = bucket_size - for r in requests: - r.input_length = input_len - r.prefix_offset = input_len - 5 - r.read_offset = input_len - r.all_input_ids = all_input_ids[r.idx] - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - cross_attention_mask = ( - cross_attention_mask.to(device) - if cross_attention_mask is not None - else None - ) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - htorch.core.mark_step() - - return cls( - batch_id=pb.id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_len, - cross_attention_mask=cross_attention_mask, - ) - - @classmethod - def batch_tokenized_inputs( - cls, - requests: Iterable[generate_pb2.Request], - tokenizer, - processor, - config, - is_warmup, - ): - image_inputs = {} - texts = [] - images = [] - batch_tokenized_inputs = {} - - for i, r in enumerate(requests): - # Each input is encoded into a list, where each element of this input list is either a string or a URL - curr_text = "" - curr_image = None - for chunk in r.input_chunks.chunks: - chunk_type = chunk.WhichOneof("chunk") - if chunk_type == "text": - curr_text += chunk.text - elif chunk_type == "image": - image = Image.open(BytesIO(chunk.image.data)) - # TODO unsure about BOS - curr_image = image - else: - raise RuntimeError(f"Invalid chunk type {chunk_type}") - - if image_text_replacement(config) not in curr_text: - if "" in curr_text: - curr_text = curr_text.replace( - "", image_text_replacement(config) - ) - else: - curr_text = image_text_replacement(config) + curr_text - - texts.append(curr_text) - if curr_image is not None: - if config.model_type == "mllama": - images.append([curr_image]) - else: - images.append(curr_image) - - if is_warmup is True: - images += [images[0]] * (len(texts) - len(images)) - - missing_inputs = 0 - dummy_images = None - if is_warmup is False: - new_bs = 
round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) - missing_inputs = new_bs - len(requests) - if missing_inputs > 0: - dummy_inputs = [] - if len(texts) > 0: - dummy_inputs = [texts[0]] * missing_inputs - dummy_images = [images[0]] * missing_inputs - texts += dummy_inputs - images += dummy_images - - processor_output = processor( - images, - texts, - truncation=True, - max_length=r.truncate, - add_special_tokens=r.add_special_tokens, - return_tensors="pt", - padding_side="left", - padding="longest", - ) - if "input_ids" in processor_output: - batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]}) - if "attention_mask" in processor_output: - batch_tokenized_inputs.update( - {"attention_mask": processor_output["attention_mask"]} - ) - if "cross_attention_mask" in processor_output: - batch_tokenized_inputs.update( - {"cross_attention_mask": processor_output["cross_attention_mask"]} - ) - if "pixel_values" in processor_output: - image_inputs.update({"pixel_values": processor_output["pixel_values"]}) - if "pixel_attention_mask" in processor_output: - image_inputs.update( - {"pixel_attention_mask": processor_output["pixel_attention_mask"]} - ) - if "aspect_ratio_ids" in processor_output: - image_inputs.update( - {"aspect_ratio_ids": processor_output["aspect_ratio_ids"]} - ) - if "aspect_ratio_mask" in processor_output: - image_inputs.update( - {"aspect_ratio_mask": processor_output["aspect_ratio_mask"]} - ) - if "image_sizes" in processor_output: - image_inputs.update({"image_sizes": processor_output["image_sizes"]}) - - return batch_tokenized_inputs, image_inputs - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor, - config, - dtype: torch.dtype, - device: torch.device, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config, is_warmup - ) - batch = cls.from_tokenized( - pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup - ) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - if "aspect_ratio_ids" in image_inputs: - batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to( - device=device - ) - else: - batch.aspect_ratio_ids = None - if "aspect_ratio_mask" in image_inputs: - batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to( - device=device - ) - else: - batch.aspect_ratio_mask = None - else: - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - batch.aspect_ratio_ids = None - batch.aspect_ratio_mask = None - batch.cross_attention_mask = None - - return batch - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate( - cls, - batches: List["CausalLMBatch"], - pad_token_id: int = 0, - is_warmup: bool = False, - ) -> "CausalLMBatch": - return cls.recombine(batches, pad_token_id, is_warmup) - - @classmethod - def recombine( - cls, - batches: List["VlmCausalLMBatch"], - pad_token_id: int, - is_warmup: bool = False, - ) -> "VlmCausalLMBatch": - if not all(b.past_key_values is not None for b in batches): - raise ValueError("KV 
cache not allocated! Cannot recombine before prefill!") - # Used for padding - - total_requests = sum(len(b) for b in batches) - new_bs = total_requests - if not is_warmup: - new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) - - if len(batches) > 1: - scenario = "CONCAT" - elif batches[0].prefilling: - scenario = "SHIFT" - else: - return batches[0] - - dbg_trace( - scenario, - f"bs:{[b.batch_size for b in batches]}->{new_bs}" - f" reqs:{[len(b) for b in batches]}", - ) - - if scenario == "SHIFT": - batch = batches[0] - batch.padding_process(pad_token_id) - return batch - - total_batch_size = 0 - max_input_length = 0 - for i, batch in enumerate(batches): - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.input_length) - # Batch attributes - requests = [] - input_lengths = [] - top_n_tokens = [] - parameters = [] - fsm_grammar_states = [] - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - top_n_tokens_tensor = None - cross_attention_mask = None - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - keep_indices = [] - for req in batch.requests: - keep_indices.append(req.idx) - - requests.extend(batch.requests) - parameters.extend([r.data.parameters for r in batch.requests]) - fsm_grammar_states.extend( - [batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices] - ) - input_lengths.extend([batch.input_length]) - top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices]) - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS)) - # # Copy to correct indices - - left_offset = max_input_length - batch.input_length - right_padding = MAX_TOTAL_TOKENS - max_input_length - input_ids[start_index:end_index, left_offset:-right_padding] = ( - batch.input_ids[keep_indices, : batch.input_length] - ) - - # Create padded tensor - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - new_bs, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[ - keep_indices - ] - - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - (new_bs, MAX_TOTAL_TOKENS), - ) - - attention_mask[ - start_index:end_index, - left_offset:-right_padding, - ] = batch.attention_mask[ - keep_indices, - : batch.input_length, - ] - - if batch.cross_attention_mask is not None: - cross_attention_mask_shape = list(batch.cross_attention_mask.shape) - cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS - cross_attention_mask_shape[0] = new_bs - cross_attention_mask_shape = torch.Size(cross_attention_mask_shape) - if cross_attention_mask is None: - cross_attention_mask = batch.cross_attention_mask.new_zeros( - cross_attention_mask_shape, - ) - cross_attention_mask[ - start_index:end_index, - left_offset:-right_padding, - ] = batch.cross_attention_mask[ - keep_indices, - : batch.input_length, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((new_bs, 1)) - 
position_ids[start_index:end_index] = batch.position_ids[keep_indices, :] - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values, tuple): - batch.past_key_values = [ - [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:]) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - past_key_values = [] - for layer_id in range(len(batches[0].past_key_values)): - if layer_id in CROSS_ATTENTION_LAYERS: - padded_past_keys_shape = list( - batches[0].past_key_values[layer_id][0].shape - ) - padded_past_keys_shape[0] = new_bs - padded_past_keys_shape = torch.Size(padded_past_keys_shape) - else: - padded_past_keys_shape = ( - new_bs, - num_heads, - MAX_TOTAL_TOKENS, - head_dim, - ) - - padded_past_keys = first_past_kvs[layer_id][0].new_zeros( - padded_past_keys_shape - ) - padded_past_values = first_past_kvs[layer_id][1].new_zeros( - padded_past_keys_shape - ) - start_index = 0 - for batch in batches: - keep_indices = [] - for req in batch.requests: - keep_indices.append(req.idx) - - left_offset = max_input_length - batch.input_length - right_padding = MAX_TOTAL_TOKENS - max_input_length - past_keys = batch.past_key_values[layer_id][0] - past_values = batch.past_key_values[layer_id][1] - # Clear reference to the original tensor - batch.past_key_values[layer_id] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - if layer_id in CROSS_ATTENTION_LAYERS: - padded_past_keys[start_index:end_index, :, :, :] = past_keys[ - keep_indices, :, :, : - ] - padded_past_values[start_index:end_index, :, :, :] = past_values[ - keep_indices, :, :, : - ] - - else: - padded_past_keys[ - start_index:end_index, :, left_offset:-right_padding, : - ] = past_keys[keep_indices, :, : batch.input_length, :] - padded_past_values[ - start_index:end_index, :, left_offset:-right_padding, : - ] = past_values[keep_indices, :, : batch.input_length, :] - - start_index = end_index - - past_key_values.append(tuple([padded_past_keys, padded_past_values])) - past_key_values = tuple(past_key_values) - - batch_id = batches[0].batch_id - top_n_tokens.extend([-1] * (new_bs - total_batch_size)) - fsm_grammar_states.extend([-1] * (new_bs - total_batch_size)) - - for idx, req in enumerate(requests): - req.idx = idx - - parameters = pad_next_token_chooser_parameters(parameters, new_bs) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - parameters, - batches[0].next_token_chooser.dtype, - batches[0].next_token_chooser.device, - batches[0].next_token_chooser.tokenizer, - fsm_grammar_states, - quantization_enabled=hq_env.is_quantization_enabled, - ) - input_length = max_input_length - - htorch.core.mark_step() - - return cls( - batch_id=batch_id, - requests=requests, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - merged_kv_cache=False, - next_token_chooser=next_token_chooser, - 
top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - input_length=input_length, - pixel_values=None, - pixel_attention_mask=None, - image_sizes=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=cross_attention_mask, - prefilling=False, - ) - - -class VlmCausalLM(Model): - def __init__( - self, - model_class, - model_id: str, - *, - processor_class=AutoProcessor, - processor_kwargs=None, - batch_class=VlmCausalLMBatch, - revision, - quantize: Optional[str] = None, - dtype, - trust_remote_code: bool, - **kwargs, - ): - adapt_transformers_to_gaudi() - if processor_kwargs is None: - processor_kwargs = {} - self.processor = processor_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - **processor_kwargs, - ) - self.batch_class = batch_class - self.prev_bs = 0 - self.quantize = quantize - - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - make_tokenizer_optional(tokenizer) - - # Create model - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - dtype = torch.bfloat16 if dtype is None else dtype - device = torch.device("hpu") - - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - # Get weight files - weight_files(model_id, revision=revision, extension=".safetensors") - - if world_size > 1: - os.environ.setdefault( - "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1" - ) - model = self.get_deepspeed_model(model_class, model_id, dtype, revision) - model = hq_env.prepare_model_for_quantization(model) - else: - # Check support for rope scaling - model_kwargs = {} - config = AutoConfig.from_pretrained(model_id) - if hasattr(config, "rope_scaling"): - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - model = model_class.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - trust_remote_code=trust_remote_code, - **model_kwargs, - ) - model = hq_env.prepare_model_for_quantization(model) - model = model.eval().to(device) - - self.enable_hpu_graph = ( - os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 - ) - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true" - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: - if LAZY_MODE == 0: - # It is said that "keep_input_mutations" is safe for inference to be done - dbg_trace("TORCH COMPILE", "Torch compiling of model") - model.model = torch.compile( - model.model, - backend="hpu_backend", - options={"keep_input_mutations": True}, - ) - - model = hq_env.setup_quantization(model) - - if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - raise ValueError(f"Model type {model.config.model_type} is not supported!") - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - if isinstance(model.config.eos_token_id, int): - tokenizer.pad_token_id = model.config.eos_token_id - elif isinstance(model.config.eos_token_id, list): - tokenizer.pad_token_id = model.config.eos_token_id[0] - else: - raise ValueError( - f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for 
tokenizer.pad_token_id" - ) - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - self.kwargs = { - "use_cache": True, - "return_dict": True, - } - - if model.config.model_type in ["llava_next"]: - self.kwargs["attn_softmax_bf16"] = True - self.kwargs["trim_logits"] = True - - if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true": - self.kwargs["use_flash_attention"] = True - if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true": - self.kwargs["flash_attention_recompute"] = True - - self.speculate = get_speculate() - if model.config.model_type == "mllama": - global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS - CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers - BASE_IMAGE_TOKENS = 0 - - super(VlmCausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - ) - - # Create profiler - ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")] - record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" - output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") - self.profiling_warmup_steps = ( - int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_steps = ( - int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 - ) - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) - if self.profiling_steps > 0: - self.hb_profiler = HabanaProfile( - wait=self.profiling_wait_steps, - warmup=self.profiling_warmup_steps, - active=self.profiling_steps, - output_dir=output_dir, - record_shapes=record_shapes, - ) - self.hb_profiler.start() - else: - self.hb_profiler = None - self.step = 0 - - @property - def batch_type(self) -> Type[VlmCausalLMBatch]: - return self.batch_class - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) - - def get_deepspeed_model( - self, - model_class, - model_id: str, - dtype: torch.dtype, - revision: Optional[str] = None, - ) -> torch.nn.Module: - import deepspeed - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - model_kwargs = {"revision": revision} - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format( - world_size, rank, local_rank - ) - ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - # Check support for rope scaling - if hasattr(config, "rope_scaling"): - config.rope_scaling = self.get_rope_scaling() - model_kwargs["rope_scaling"] = self.get_rope_scaling() - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = model_class.from_config(config, torch_dtype=dtype) - else: - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = model_class.from_pretrained( - model_id, torch_dtype=dtype, **model_kwargs - ) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - ds_inference_kwargs["injection_policy"] = get_ds_injection_policy( - model.language_model.config - ) - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [ - str(f) - for f in weight_files( - model_id, revision=revision, extension=".safetensors" - ) - ] - data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, checkpoints_json) - checkpoints_json.flush() - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - - return model.module - - def get_rope_scaling(self) -> Optional[Dict]: - rope_scaling = os.getenv("ROPE_SCALING", None) - if rope_scaling is None: - return None - - rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) - return {"type": rope_scaling, "factor": float(rope_factor)} - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - - def decode_token( - self, - all_input_ids: List[int], - prefix_offset: int = 0, - read_offset: int = 0, - ) -> Tuple[str, int, int]: - if is_tokenizer_transparent(self.tokenizer): - new_text = self.tokenizer.decode( - all_input_ids[read_offset:], skip_special_tokens=False - ) - return new_text, read_offset, len(all_input_ids) - else: - return super().decode_token(all_input_ids, prefix_offset, read_offset) - - def forward( - self, - batch: VlmCausalLMBatch, - bypass_hpu_graph: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - # Model Forward - kwargs = { - "input_ids": batch.input_ids, - "attention_mask": batch.attention_mask, - "past_key_values": batch.past_key_values, - "token_idx": batch.token_idx, - "pixel_values": batch.pixel_values, - } - - if self.model.config.model_type == "mllama": - kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids - kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask - kwargs["cross_attention_mask"] = batch.cross_attention_mask - else: - kwargs["image_sizes"] = batch.image_sizes - - hpu_kwargs = {} - # Optimum Habana got "lazy_mode" key-val only supported for llama type of models - if self.model.config.model_type == "llama": - hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 - - if self.has_position_ids: - kwargs["position_ids"] = batch.position_ids - if bypass_hpu_graph is not None: - hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph 
- - kwargs.update(self.kwargs) - model_inputs = self.model.prepare_inputs_for_generation(**kwargs) - - if batch.past_key_values is not None: - return self.model.forward(**model_inputs, **hpu_kwargs) - else: - outputs = self.model.forward(**model_inputs, **hpu_kwargs) - return outputs.logits, outputs.past_key_values - - @tracer.start_as_current_span("generate_token") - def generate_token( - self, batches: list[VlmCausalLMBatch], is_warmup: bool = False - ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]: - - start = time.time_ns() - # Results - generations: List[Generation] = [] - prev_batches = [] - requests_to_generate = [] - # In order to pipeline any actions on CPU we perform the operation in 3 main stages: - # Stage 1. Collect next token ids of any previously started generations - for batch_id, batch in enumerate(batches): - if batch.logits is not None: - logits = batch.logits - past = batch.past - prefill = batch.past_key_values is None - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = ( - batch.attention_mask.shape[-1] - batch.right_padding - ) - token_idx = torch.tensor(token_idx_scalar).to(self.device) - - # Select next token - input_length = batch.input_length - if logits.shape[-2] > 1: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, - logits[:, input_length - 1 : input_length, :].squeeze(-2), - self.speculate, - ) - ) - else: - next_token_ids, next_token_logprobs, logprobs, _, _ = ( - batch.next_token_chooser( - batch.input_ids, logits.squeeze(-2), self.speculate - ) - ) - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - logprobs, - accepted_ids, - ) - - prev_batches.append( - { - "next_token_ids": next_token_ids, - "next_token_logprobs": next_token_logprobs, - } - ) - - for req_idx, req in enumerate(batch.requests): - requests_to_generate.append( - { - "req": req, - "prev_req_idx": req.idx, - "batch_id": batch_id, - "seed": batch.next_token_chooser.seeds[req_idx], - "do_sample": batch.next_token_chooser.do_sample[req_idx], - "top_n_tokens": batch.top_n_tokens[req_idx], - "top_token_ids": batch_top_token_ids[req_idx], - "top_token_logprobs": batch_top_token_logprobs[req_idx], - "grammar_state": batch.next_token_chooser.fsm_grammar_states[ - req.idx - ], - } - ) - - htorch.core.mark_step() - - # Add new token into input_ids - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Update attention_mask as we added a new token to input_ids - batch.attention_mask.index_fill_(1, token_idx, 1) - - # add cross-attn mask for new token - if batch.cross_attention_mask is not None: - cross_attention_mask_prev = batch.cross_attention_mask - if token_idx is not None: - mask = cross_attention_mask_prev[ - :, token_idx - 2 : token_idx - 1, ... - ] - cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) - batch.cross_attention_mask = cross_attention_mask_prev - - # Adjust lengths - batch.input_length += 1 - # Update position_ids - if prefill: - batch.position_ids = ( - torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 - ) - else: - batch.position_ids += 1 - # Update past key values - if prefill: - batch.past_key_values = past - - htorch.core.mark_step() - - # Stage 2. 
Prepare new batch for speculative scheduling - if len(batches) > 1: - batch = self.batch_type.concatenate( - batches, self.tokenizer.pad_token_id, is_warmup - ) - else: - batch = batches[0] - - prefill = batch.past_key_values is None - - # Check if we need to do any bookkeeping first - if not prefill: - batch = self.batch_type.recombine( - [batch], self.tokenizer.pad_token_id, is_warmup - ) - - scenario = "PREFILL" if prefill else "GENERATE" - if ( - self.enable_hpu_graph - and self.limit_hpu_graph - and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - != self.prev_bs - ): - self.model.clear_cache() - self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) - dbg_trace( - scenario, - f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}", - ) - # assert batch.right_padding > 0, 'No more room for next token!' - - # Execute batch - if prefill: - # no right padding for prefill - # token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - batch.logits, batch.past = self.forward( - batch, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - - elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): - # Don't schedule next forward if max_new_tokens for all requests equals 1 - # - we've already generated the first and only needed token in the prefill phase - pass - else: - # token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) - batch.logits = self.forward( - batch, - bypass_hpu_graph=( - prefill and self.limit_hpu_graph if self.enable_hpu_graph else None - ), - ) - - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.aspect_ratio_ids is not None: - batch.aspect_ratio_ids = None - if batch.aspect_ratio_mask is not None: - batch.aspect_ratio_mask = None - - htorch.core.mark_step() - - start_decode = time.time_ns() - - # Stage 3. 
Finish and return previous generations - stopped = len(requests_to_generate) > 0 - for prev_batch in prev_batches: - prev_batch["next_token_logprobs"] = prev_batch[ - "next_token_logprobs" - ].tolist() - prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu() - htorch.core.mark_step() - - for req_data in requests_to_generate: - req = req_data["req"] - i = req_data["prev_req_idx"] - prev_batch_id = req_data["batch_id"] - assert len(prev_batches) > prev_batch_id - next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"] - next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"] - - request = req.data - input_length = req.input_length - prefix_offset = req.prefix_offset - read_offset = req.read_offset - do_sample = req_data["do_sample"] - seed = req_data["seed"] - stopping_criteria = req.stopping_criteria - all_input_ids = req.all_input_ids - next_token_id = next_token_ids_cpu[i] - next_token_logprob = next_token_logprobs[i] - top_n_tokens = req_data["top_n_tokens"] - top_token_ids = req_data["top_token_ids"] - top_token_logprobs = req_data["top_token_logprobs"] - grammar_state = req_data["grammar_state"] - - # Append next token to all tokens - all_input_ids[input_length] = next_token_id - new_input_length = input_length + 1 - - # Generated token - if ( - is_tokenizer_transparent(self.tokenizer) - and len(stopping_criteria.stop_sequence_criterias) == 0 - ): - next_token_text = "" - else: - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[0:new_input_length, 0], prefix_offset, read_offset - ) - - # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if not stop: - stopped = False - - # Shard generations - # All generations will be appended in the rust sharded client - if i % self.world_size == self.rank: - if stop: - # Decode generated tokens - if is_tokenizer_transparent(self.tokenizer): - output_text = None - else: - output_text = self.decode( - all_input_ids[ - new_input_length - - stopping_criteria.current_tokens : new_input_length, - 0, - ] - ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, - ) - else: - generated_text = None - - # Prefill - if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: - # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + next_token_logprobs - prefill_token_ids = all_input_ids[0 : new_input_length - 1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - prefill_tokens = Tokens( - prefill_token_ids, - prefill_logprobs, - prefill_texts, - is_special=[], - ) - else: - prefill_tokens = None - - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, - ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None - - generation = Generation( - request.id, - prefill_tokens, - Tokens( - [next_token_id], - [next_token_logprob], - [next_token_text], - 
[next_token_id in self.all_special_ids], - ), - generated_text, - top_tokens, - ) - - generations.append(generation) - - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single_with_past_state( - req.idx, next_token_id, grammar_state - ) - ) - - req.all_input_ids = all_input_ids - req.input_length = new_input_length - req.prefix_offset = prefix_offset - req.read_offset = read_offset - - htorch.core.mark_step() - self.step = self.step + 1 - if self.hb_profiler is not None: - if ( - self.step - > self.profiling_wait_steps - + self.profiling_warmup_steps - + self.profiling_steps - ): - self.hb_profiler.stop() - else: - self.hb_profiler.step() - - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch if not stopped else None, (forward_ns, decode_ns) - - def batch_from_pb(self, batch, is_warmup): - return self.batch_type.from_pb_processor( - batch, - self.tokenizer, - self.processor, - self.model.config, - self.dtype, - self.device, - is_warmup, - ) - - def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): - batch = copy.deepcopy(request.batch) - for req in batch.requests: - req.truncate = seq_len - - for i in range(len(batch.requests) - batch_size): - batch.requests.pop() - - return self.batch_from_pb(batch, is_warmup) - - def warmup( - self, request: generate_pb2.WarmupRequest - ) -> Tuple[Optional[int], Optional[int], Optional[int]]: - global MAX_TOTAL_TOKENS - MAX_TOTAL_TOKENS = request.max_total_tokens - batch = self.batch_from_pb(request.batch, is_warmup=True) - max_input_tokens = request.max_input_tokens - max_prefill_batch_size = batch.input_ids.shape[0] - max_batch_size_str = os.environ.get("MAX_BATCH_SIZE") - if max_batch_size_str is not None: - MAX_BATCH_SIZE = int(max_batch_size_str) - else: - raise ValueError("MAX_BATCH_SIZE is not set") - - try: - # max prefill batch size warmup - _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) - except Exception: - raise RuntimeError( - f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " - f"You need to decrease `--max-batch-prefill-tokens`" - ) - - global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST - PREFILL_WARMUP_BATCH_SIZE_LIST = [] - batch_size = 1 - while batch_size <= max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size: - PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) - - if self.model.config.model_type == "mllama": - seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF - else: - seq_len = BASE_IMAGE_TOKENS - - PREFILL_WARMUP_SEQLEN_LIST = [] - i = 0 - while seq_len <= max_input_tokens: - PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) - seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i) - i += 1 - if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_tokens: - PREFILL_WARMUP_SEQLEN_LIST.append(max_input_tokens) - - # Prefill and decode warmup - DECODE_WARMUP_BATCH_SIZE_LIST = [] - prefill_batch = None - decode_batch = None - logger.info( - f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" - f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" - ) - try: - for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST: - for seq_len in PREFILL_WARMUP_SEQLEN_LIST: - logger.info( - f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..." 
- ) - batch = self.generate_warmup_batch( - request, seq_len, batch_size, is_warmup=True - ) - _, prefill_batch, _ = self.generate_token([batch], is_warmup=True) - assert prefill_batch is not None - _, decode_batch, _ = self.generate_token( - [prefill_batch], is_warmup=True - ) - - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - - except Exception: - raise RuntimeError( - "Not enough memory to handle following prefill and decode warmup." - "You need to decrease `--max-batch-prefill-tokens`" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info(f"Prefill warmup successful.\n" f"Memory stats: {mem_stats} ") - - max_decode_batch_size = MAX_BATCH_SIZE - batch_size = max_prefill_batch_size * 2 - logger.info(f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n") - # Decode warmup with bigger batch_size - try: - if ( - DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size - and batch_size <= max_decode_batch_size - ): - batches = [] - while batch_size <= max_decode_batch_size: - for i in range(int(batch_size / max_prefill_batch_size)): - logger.info( - f"Decode warmup for `batch_size={batch_size}`, this may take a while..." - ) - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0] - 1, - max_prefill_batch_size, - is_warmup=True, - ) - _, prefill_batch, _ = self.generate_token( - [batch], is_warmup=True - ) - batches.append(prefill_batch) - - _, decode_batch, _ = self.generate_token(batches, is_warmup=True) - DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) - batch_size = batch_size * 2 - batches.clear() - - if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: - max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2 - batch_size = max_decode_batch_size - for i in range(int(max_decode_batch_size / 2)): - batch = self.generate_warmup_batch( - request, - PREFILL_WARMUP_SEQLEN_LIST[0] - 1, - 2, - is_warmup=True, - ) - _, prefill_batch, _ = self.generate_token( - [batch], is_warmup=True - ) - batches.append(prefill_batch) - _, decode_batch, _ = self.generate_token(batches, is_warmup=True) - DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) - - except Exception: - raise RuntimeError( - f"Not enough memory to handle batch_size({batch_size}) decode warmup." 
- f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" - f"max_decode_batch_size is {max_decode_batch_size}" - f"You need to decrease env `MAX_BATCH_SIZE` or '--max_batch_size'" - ) - - mem_stats = get_hpu_memory_stats(self.device) - logger.info(f"Decode warmup successful.\n" f"Memory stats: {mem_stats}") - - max_supported_total_tokens = MAX_BATCH_SIZE * MAX_TOTAL_TOKENS - max_input_tokens = max_input_tokens - max_total_tokens = MAX_TOTAL_TOKENS - - return max_supported_total_tokens, max_input_tokens, max_total_tokens diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py index 5a7d2117..f5080ec3 100644 --- a/backends/gaudi/server/text_generation_server/server.py +++ b/backends/gaudi/server/text_generation_server/server.py @@ -23,26 +23,8 @@ from text_generation_server.models.globals import set_adapter_to_index from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.tokens import make_tokenizer_optional from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens +from text_generation_server.models import VLM_BATCH_TYPES -try: - from text_generation_server.models.pali_gemma import PaliGemmaBatch - from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch - from text_generation_server.models.vlm_causal_lm import ( - VlmCausalLMBatch, - ) - from text_generation_server.models.flash_vlm_causal_lm import ( - FlashVlmCausalLMBatch, - ) - - VLM_BATCH_TYPES = { - PaliGemmaBatch, - VlmCausalLMBatch, - FlashVlmCausalLMBatch, - FlashMllamaCausalLMBatch, - } -except (ImportError, NotImplementedError): - # These imports can fail on CPU/Non flash. - VLM_BATCH_TYPES = set() from text_generation_server.utils.version import ( is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION, @@ -224,6 +206,7 @@ def serve( quantize: Optional[str], speculate: Optional[int], dtype: Optional[str], + kv_cache_dtype: Optional[str], trust_remote_code: bool, uds_path: Path, max_input_tokens: int, @@ -236,6 +219,7 @@ def serve( quantize: Optional[str] = None, speculate: Optional[int] = None, dtype: Optional[str] = None, + kv_cache_dtype: Optional[str] = None, trust_remote_code: bool = False, ): if not is_driver_compatible(): @@ -279,6 +263,7 @@ def serve( quantize, speculate, data_type, + kv_cache_dtype, trust_remote_code, max_input_tokens, adapter_to_index, @@ -326,6 +311,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, ) ) diff --git a/backends/gaudi/server/text_generation_server/tgi_service.py b/backends/gaudi/server/text_generation_server/tgi_service.py index 18e88a7e..12317127 100644 --- a/backends/gaudi/server/text_generation_server/tgi_service.py +++ b/backends/gaudi/server/text_generation_server/tgi_service.py @@ -31,6 +31,7 @@ def main(args): trust_remote_code=args.trust_remote_code, uds_path=args.uds_path, max_input_tokens=args.max_input_tokens, + kv_cache_dtype="auto", ) diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py index 8bbcad6a..690da54e 100644 --- a/backends/gaudi/server/text_generation_server/utils/debug.py +++ b/backends/gaudi/server/text_generation_server/utils/debug.py @@ -4,8 +4,8 @@ import os import glob import time -from optimum.habana.utils import to_gb_rounded import habana_frameworks.torch as htorch +import numpy as np START_TS = None DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME") @@ -14,6 +14,19 @@ if 
"GRAPH_VISUALIZATION" in os.environ: os.remove(f) +def to_gb_rounded(mem: float) -> float: + """ + Rounds and converts to GB. + + Args: + mem (float): memory in bytes + + Returns: + float: memory in GB rounded to the second decimal + """ + return np.round(mem / 1024**3, 2) + + def count_hpu_graphs(): return len(glob.glob(".graph_dumps/*PreGraph*")) diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py index 22560dd7..bdcfc9fa 100644 --- a/backends/gaudi/server/text_generation_server/utils/import_utils.py +++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py @@ -1,18 +1,9 @@ import torch -from loguru import logger def get_hpu_free_memory(device, memory_fraction): - from habana_frameworks.torch.hpu import memory_stats - - device_id = device.index - mem_stats = memory_stats(device_id) - logger.info(f"mem_stats: {mem_stats}") - total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"] - free_memory = max( - 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"]) - ) - return free_memory + free_hpu_memory, _ = torch.hpu.mem_get_info() + return free_hpu_memory def synchronize_hpu(device): diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index a8faf4a5..192963c4 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -1,7 +1,7 @@ import json import os from dataclasses import dataclass -from typing import Optional +from typing import Optional, List from huggingface_hub import hf_hub_download from text_generation_server.utils.weights import ( @@ -18,6 +18,8 @@ class _QuantizerConfig: groupsize: int quant_method: str sym: bool + weight_block_size: Optional[List[int]] + modules_to_not_convert: List[str] @dataclass @@ -25,7 +27,20 @@ class _FP8QuantizerConfig: activation_scale_ub: float -# We should probably do this with Pytantic JSON deserialization, +def _get_config_json(model_id: str, revision: Optional[str], filename: str): + if os.path.exists( + os.path.join( + model_id, + ) + ): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + return json.load(f) + + +# We should probably do this with Pydantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. 
def _get_quantizer_config(model_id, revision): bits = 4 @@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision): checkpoint_format = None sym = False desc_act = False + weight_block_size = None + modules_to_not_convert = [] filename = "config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download(model_id, filename=filename, revision=revision) - with open(filename, "r") as f: - data = json.load(f) - + data = _get_config_json(model_id, revision, filename) # FP8 config if data["quantization_config"]["quant_method"] == "fbgemm_fp8": return _FP8QuantizerConfig( activation_scale_ub=data["quantization_config"]["activation_scale_ub"] ) + weight_block_size = data["quantization_config"].get("weight_block_size", None) if "zero_point" in data["quantization_config"]: sym = not data["quantization_config"]["zero_point"] @@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision): # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] checkpoint_format = data["quantization_config"].get("checkpoint_format") - desc_act = data["quantization_config"]["desc_act"] + desc_act = data["quantization_config"].get("desc_act", False) + modules_to_not_convert = data["quantization_config"].get( + "modules_to_not_convert", [] + ) + if modules_to_not_convert is None: + modules_to_not_convert = [] except Exception: filename = "quantize_config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) + data = _get_config_json(model_id, revision, filename) bits = data["bits"] groupsize = data["group_size"] @@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision): except Exception: filename = "quant_config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) + data = _get_config_json(model_id, revision, filename) bits = data["w_bit"] groupsize = data["q_group_size"] desc_act = data["desc_act"] @@ -111,12 +114,21 @@ def _get_quantizer_config(model_id, revision): checkpoint_format=checkpoint_format, sym=sym, desc_act=desc_act, + weight_block_size=weight_block_size, + modules_to_not_convert=modules_to_not_convert, ) def get_loader( quantize: Optional[str], model_id: str, revision: Optional[str] ) -> WeightsLoader: + if quantize == "compressed-tensors": + config = _get_config_json(model_id, revision, "config.json") + from text_generation_server.layers.compressed_tensors import ( + CompressedTensorsLoader, + ) + + return CompressedTensorsLoader(config) quantizer_config = _get_quantizer_config(model_id, revision) if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader @@ -134,6 +146,7 @@ def get_loader( quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, + modules_to_not_convert=quantizer_config.modules_to_not_convert, ) elif quantize == "fp8" or quantize is None: from text_generation_server.layers.fp8 import HybridFP8UnquantLoader @@ -141,9 +154,14 @@ def get_loader( # Since the default for the quantize config is _QuantizerConfig, # we need to add this check to not get 
an attribute error activation_scale_ub = None + weight_block_size = quantizer_config.weight_block_size if isinstance(quantizer_config, _FP8QuantizerConfig): activation_scale_ub = quantizer_config.activation_scale_ub - return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") + return HybridFP8UnquantLoader( + activation_scale_ub, + to_fp8=quantize == "fp8", + weight_block_size=weight_block_size, + ) else: raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py index f54b6ae8..74b53dfa 100644 --- a/backends/gaudi/server/text_generation_server/utils/version.py +++ b/backends/gaudi/server/text_generation_server/utils/version.py @@ -1,5 +1,30 @@ -from optimum.habana.utils import get_driver_version from packaging.version import Version +from packaging import version +import subprocess + + +def get_driver_version(): + """ + Returns the driver version. + """ + # Enable console printing for `hl-smi` check + output = subprocess.run( + "hl-smi", + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env={"ENABLE_CONSOLE": "true"}, + ) + if output.returncode == 0 and output.stdout: + return version.parse( + output.stdout.split("\n")[2] + .replace(" ", "") + .split(":")[1][:-1] + .split("-")[0] + ) + return None + MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0") diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py index acd598d7..4edae0d4 100644 --- a/backends/gaudi/server/text_generation_server/utils/weights.py +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -62,6 +62,14 @@ class WeightsLoader(ABC): """ ... + @abstractmethod + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + """ + Get the weights at the given prefixes, column-split them for tensor + parallelim, and then concatenate the weights along the given dimension. + """ + ... + @abstractmethod def get_weights_row(self, weights: "Weights", prefix: str): """ @@ -130,6 +138,10 @@ class DefaultWeightsLoader(WeightsLoader): weights.get_sharded(f"{prefix}.weight", dim=1), ) + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_tensor(f"{p}.weight") for p in prefixes] + return self.weight_class(torch.cat(w, dim=dim)) + class Weights: def __init__( @@ -303,7 +315,7 @@ class Weights: world_size = self.process_group.size() rank = self.process_group.rank() - tensors = [] + tensors_slices = [] block_offset = 0 for block_size in block_sizes: assert ( @@ -312,15 +324,18 @@ class Weights: shard_block_size = block_size // world_size start = rank * shard_block_size stop = (rank + 1) * shard_block_size - if dim == 0: - tensor = slice_[block_offset + start : block_offset + stop] - elif dim == 1: - tensor = slice_[:, block_offset + start : block_offset + stop] - else: - raise NotImplementedError("Currently only dim=0 or dim=1 is supported") - tensors.append(tensor) + tensors_slices += range(block_offset + start, block_offset + stop) block_offset += block_size - tensor = torch.cat(tensors, dim=dim) + + if dim == 0: + tensor = slice_[tensors_slices, ...] + elif dim == 1 or dim == -2: + tensor = slice_[:, tensors_slices, ...] 
+ elif dim == 2 or dim == -1: + tensor = slice_[..., tensors_slices] + else: + raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") + tensor = tensor.to(device=self.device) # Avoid casting quantizer dtypes. @@ -390,6 +405,9 @@ class Weights: def get_weights_row(self, prefix: str): return self.weights_loader.get_weights_row(self, prefix) + def get_multi_weights(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights(self, prefixes, dim) + @contextmanager def use_loader(self, weights_loader: WeightsLoader): """ diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py index b3887e14..10a4d7a2 100644 --- a/backends/neuron/server/text_generation_server/generator.py +++ b/backends/neuron/server/text_generation_server/generator.py @@ -7,7 +7,8 @@ from typing import List, Optional, Tuple import torch from loguru import logger -from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from optimum.neuron.configuration_utils import NeuronConfig from transformers.generation import GenerationConfig from optimum.neuron import NeuronModelForCausalLM @@ -175,6 +176,12 @@ class Slot: self._generation_config.top_p = request.parameters.top_p if request.parameters.typical_p != 0: self._generation_config.typical_p = request.parameters.typical_p + else: + # Set the sampling parameters to emulate greedy decoding when using on-device sampling + self._generation_config.temperature = 1.0 + self._generation_config.top_k = 1 + self._generation_config.top_p = 1.0 + self._generation_config.typical_p = 1.0 if request.parameters.repetition_penalty != 0: self._generation_config.repetition_penalty = ( request.parameters.repetition_penalty @@ -211,19 +218,11 @@ class Slot: self._mask = attention_mask.clone() self._selector = selector - def pause(self, reset_on_pause: bool): + def pause(self): """Mark the current slot as paused for generation. Note that the KV cache for this slot will still be filled. """ - if reset_on_pause: - # Drop the last token as it will be added back when resuming the slot - self._generated_tokens -= 1 - # Since generated tokens are now part of the prefill, we need to reevaluate - # max_new_tokens for the next generation - self._generation_config.max_new_tokens = ( - self._max_new_tokens - self._generated_tokens - ) self._state = Slot.State.PAUSE def resume(self): @@ -340,16 +339,27 @@ class NeuronGenerator(Generator): tokenizer: PreTrainedTokenizerBase, ): self.model = model - self.rebuild_cache_on_prefill = not self.model.continuous_batching + if not isinstance(self.model, NeuronModelForCausalLM): + raise ValueError("The model must be a NeuronModelForCausalLM.") + if not model.neuron_config.continuous_batching: + raise ValueError( + "The neuron model must be compiled with continuous_batching=True." 
+ ) # Specify padding and truncation options for decoder-only architecture tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" tokenizer.truncation_side = "left" self.tokenizer = tokenizer self.special_tokens = self.tokenizer.all_special_ids - self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)] + self.slots = [ + Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size) + ] self.batch_id = 0 + @property + def on_device_sampling(self) -> bool: + return getattr(self.model.neuron_config, "on_device_sampling", False) + @property def info(self) -> InfoResponse: """Returns the expected InfoResponse.""" @@ -371,14 +381,22 @@ class NeuronGenerator(Generator): The maximum number of tokens the model supports. """ # Just check that the warmup request parameters match the model capacity - batch_size = self.model.batch_size + batch_size = self.model.neuron_config.batch_size if len(batch.requests) > batch_size: raise ValueError( - f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." + f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE." ) self.prefill(batch) self.clear() - return self.model.batch_size * self.model.max_length + return ( + self.model.neuron_config.batch_size + * self.model.neuron_config.sequence_length + ) + + def max_prefill_length(self) -> int: + if hasattr(self.model.neuron_config, "max_context_length"): + return self.model.neuron_config.max_context_length + return self.model.neuron_config.sequence_length def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: """Prefill new requests. @@ -398,7 +416,7 @@ class NeuronGenerator(Generator): if len(empty_slots) < len(batch.requests): raise ValueError( f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots." - f" Please align max_batch_size with the static batch size: {self.model.batch_size}." + f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}." ) # Assign each request to an empty slot logger.debug( @@ -412,14 +430,8 @@ class NeuronGenerator(Generator): logger.debug( f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}" ) - if self.rebuild_cache_on_prefill: - # We will clear pending slots and prefill all slots - prefill_slots = self.slots - seq_ids = None - else: - # We only need to pass inputs for the new requests - prefill_slots = new_slots - seq_ids = torch.tensor([slot.id for slot in prefill_slots]) + prefill_slots = new_slots + seq_ids = torch.tensor([slot.id for slot in prefill_slots]) # Reconstruct the full inputs (without padding) as seen by the model. 
# This comprises: # - the inputs for new requests, @@ -431,8 +443,10 @@ class NeuronGenerator(Generator): inputs.append(slot.cached_text) # Apply truncation, making sure we fit into static dimensions if slot.truncate == 0: - max_length = self.model.max_length - elif slot.truncate > max_length and slot.truncate < self.model.max_length: + max_length = self.max_prefill_length() + elif ( + slot.truncate > max_length and slot.truncate < self.max_prefill_length() + ): max_length = slot.truncate # Tokenize with padding and truncation padded_inputs = self.tokenizer( @@ -444,13 +458,12 @@ class NeuronGenerator(Generator): ) input_ids = padded_inputs.input_ids attention_mask = padded_inputs.attention_mask + sampling_params = ( + torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None + ) # Pause previously active slots during generation - next_tokens = [] for slot in active_slots: - slot.pause(reset_on_pause=self.rebuild_cache_on_prefill) - if self.rebuild_cache_on_prefill: - # The slot will be reset, so we need to store its next token - next_tokens.append(slot.next_token) + slot.pause() # Each slot must be reset with the padded inputs and masks for i, slot in enumerate(prefill_slots): if slot.state != slot.state.EMPTY: @@ -464,29 +477,33 @@ class NeuronGenerator(Generator): slot_input_ids, slot.generation_config, self.model, - self.model.max_length, + self.model.neuron_config.sequence_length, tokenizer=self.tokenizer, seed=slot.seed, ) slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64) slot_attention_mask = attention_mask[i] slot.reset(slot_input_ids, slot_attention_mask, selector) + if sampling_params is not None: + sampling_params[i, 0] = slot.generation_config.top_k + sampling_params[i, 1] = slot.generation_config.top_p + sampling_params[i, 2] = slot.generation_config.temperature # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored, # as they have already been generated and sent back in the last decode. 
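As a side note, a small self-contained sketch of the `sampling_params` tensor built in the prefill and decode paths: one row per slot, columns ordered `[top_k, top_p, temperature]` as in the code, with the greedy-emulation defaults applied when a request does not sample. The `FakeGenerationConfig` stand-in is hypothetical; only the column layout follows the diff.

```python
# Sketch of the on-device sampling tensor: one row per slot,
# columns ordered [top_k, top_p, temperature].
from dataclasses import dataclass

import torch


@dataclass
class FakeGenerationConfig:  # hypothetical stand-in for a slot's generation config
    top_k: int = 1           # greedy-emulation defaults for non-sampling requests
    top_p: float = 1.0
    temperature: float = 1.0


slots = [
    FakeGenerationConfig(),                                      # greedy request
    FakeGenerationConfig(top_k=50, top_p=0.9, temperature=0.8),  # sampling request
]
sampling_params = torch.zeros(len(slots), 3)
for i, cfg in enumerate(slots):
    sampling_params[i, 0] = cfg.top_k
    sampling_params[i, 1] = cfg.top_p
    sampling_params[i, 2] = cfg.temperature
print(sampling_params)
# tensor([[ 1.0000,  1.0000,  1.0000],
#         [50.0000,  0.9000,  0.8000]])
```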
model_inputs = self.model.prepare_inputs_for_prefill( - input_ids, attention_mask, seq_ids + input_ids, + attention_mask=attention_mask, + seq_ids=seq_ids, + sampling_params=sampling_params, ) - logits = self.model(**model_inputs)[0] + tokens_or_logits = self.model(**model_inputs)[0] generation, next_batch = self._generate_token( - prefill_slots, self.batch_id, logits, input_ids + prefill_slots, self.batch_id, tokens_or_logits, input_ids ) self.batch_id += 1 # Reactivate previously active slots for the next decode for i, slot in enumerate(active_slots): slot.resume() - if self.rebuild_cache_on_prefill: - # Append back the next token - slot.append(next_tokens[i]) logger.debug("Model ready for decoding") if next_batch is not None: logger.debug( @@ -530,12 +547,8 @@ class NeuronGenerator(Generator): raise ValueError( "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)" ) - if self.model.continuous_batching: - decode_slots = active_slots - seq_ids = torch.tensor([slot.id for slot in decode_slots]) - else: - decode_slots = self.slots - seq_ids = None + decode_slots = active_slots + seq_ids = torch.tensor([slot.id for slot in decode_slots]) # Reconstruct input_ids and attention_mask from decode slots n_slots = len(decode_slots) input_ids = torch.full( @@ -545,22 +558,32 @@ class NeuronGenerator(Generator): for slot in decode_slots: max_length = max(max_length, slot.attention_mask.size(-1)) attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64) + sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None for i, slot in enumerate(decode_slots): if slot.state != Slot.State.EMPTY: # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached) input_ids[i, 0] = slot.next_token attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask + if sampling_params is not None: + sampling_params[i, 0] = slot.generation_config.top_k + sampling_params[i, 1] = slot.generation_config.top_p + sampling_params[i, 2] = slot.generation_config.temperature model_inputs = self.model.prepare_inputs_for_decode( - input_ids, attention_mask, seq_ids + input_ids, + attention_mask=attention_mask, + seq_ids=seq_ids, + sampling_params=sampling_params, + ) + tokens_or_logits = self.model(**model_inputs)[0] + return self._generate_token( + decode_slots, next_batch_id, tokens_or_logits, input_ids ) - logits = self.model(**model_inputs)[0] - return self._generate_token(decode_slots, next_batch_id, logits, input_ids) def _generate_token( self, slots: List[Slot], next_batch_id: int, - logits: torch.Tensor, + tokens_or_logits: torch.Tensor, input_ids: torch.LongTensor, ) -> Tuple[List[Generation], CachedBatch]: generations = [] @@ -569,9 +592,12 @@ class NeuronGenerator(Generator): if slot.state != Slot.State.READY: continue request_id = slot.request_id - next_token_logits = logits[i : i + 1, -1, :] slot_input_ids = input_ids[i : i + 1, :] - next_token = slot.select(slot_input_ids, next_token_logits) + if self.on_device_sampling: + next_token = tokens_or_logits[i] + else: + next_token_logits = tokens_or_logits[i : i + 1, -1, :] + next_token = slot.select(slot_input_ids, next_token_logits) next_token_text = slot.append(next_token) generated_text = None finish_reason = None @@ -622,7 +648,7 @@ class NeuronGenerator(Generator): def _cached_batch(self, batch_id: int, request_ids: List): size = len(request_ids) - max_tokens = size * self.model.max_length + max_tokens = size * 
self.model.neuron_config.sequence_length return CachedBatch( id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens ) @@ -671,8 +697,16 @@ class NeuronGenerator(Generator): Returns: A NeuronGenerator. """ - config = AutoConfig.from_pretrained(model_id) - neuron_config = getattr(config, "neuron", None) + try: + neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_id, + revision, + e, + ) + neuron_config = None start = time.time() if neuron_config is None: export_kwargs = get_export_kwargs_from_env() diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py index 2151a921..d281b175 100644 --- a/backends/neuron/server/text_generation_server/model.py +++ b/backends/neuron/server/text_generation_server/model.py @@ -6,10 +6,12 @@ from typing import Optional from huggingface_hub import snapshot_download from huggingface_hub.constants import HF_HUB_CACHE from loguru import logger -from transformers import AutoConfig -from optimum.neuron import NeuronModelForCausalLM -from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.cache import get_hub_cached_entries +from optimum.neuron.configuration_utils import NeuronConfig + + +from .tgi_env import check_env_and_neuron_config_compatibility def get_export_kwargs_from_env(): @@ -24,7 +26,6 @@ def get_export_kwargs_from_env(): num_cores = int(num_cores) auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None) return { - "task": "text-generation", "batch_size": batch_size, "sequence_length": sequence_length, "num_cores": num_cores, @@ -32,20 +33,15 @@ def get_export_kwargs_from_env(): } -def is_cached(model_id, neuron_config): +def is_cached(model_id): # Look for cached entries for the specified model in_cache = False - entries = get_hub_cached_entries(model_id, "inference") + entries = get_hub_cached_entries(model_id) # Look for compatible entries for entry in entries: - compatible = True - for key, value in neuron_config.items(): - # Only weights can be different - if key in ["checkpoint_id", "checkpoint_revision"]: - continue - if entry[key] != value: - compatible = False - if compatible: + if check_env_and_neuron_config_compatibility( + entry, check_compiler_version=True + ): in_cache = True break return in_cache @@ -87,8 +83,16 @@ def fetch_model( revision = None # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) # Note that the model may already be present in the cache. 
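To make the cache check in `is_cached` more concrete, here is a rough sketch (assumed field names, simplified rules) of what "compatible" means: the values requested through the TGI env vars must not exceed what the model was compiled for. The full logic, including compiler-version and dtype checks, lives in `check_env_and_neuron_config_compatibility` further below.

```python
# Simplified compatibility sketch: requested env values must not exceed
# the compiled neuron config. Field names follow the neuron config dict.
import os


def env_fits_compiled_config(neuron_config: dict) -> bool:
    checks = [
        ("MAX_BATCH_SIZE", "batch_size"),
        ("MAX_TOTAL_TOKENS", "sequence_length"),
        ("HF_NUM_CORES", "tp_degree"),
    ]
    for env_var, key in checks:
        requested = os.getenv(env_var)
        if requested is not None and neuron_config[key] < int(requested):
            return False
    return True


compiled = {"batch_size": 4, "sequence_length": 4096, "tp_degree": 2}
os.environ["MAX_BATCH_SIZE"] = "8"   # more than the compiled batch size
print(env_fits_compiled_config(compiled))  # False
os.environ["MAX_BATCH_SIZE"] = "2"
print(env_fits_compiled_config(compiled))  # True
```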
- config = AutoConfig.from_pretrained(model_id, revision=revision) - neuron_config = getattr(config, "neuron", None) + try: + neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_id, + revision, + e, + ) + neuron_config = None if neuron_config is not None: if os.path.isdir(model_id): return model_id @@ -99,16 +103,11 @@ def fetch_model( log_cache_size() return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") # Model needs to be exported: look for compatible cached entries on the hub - export_kwargs = get_export_kwargs_from_env() - export_config = NeuronModelForCausalLM.get_export_config( - model_id, config, revision=revision, **export_kwargs - ) - neuron_config = export_config.neuron - if not is_cached(model_id, neuron_config): + if not is_cached(model_id): hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache" neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi" error_msg = ( - f"No cached version found for {model_id} with {neuron_config}." + f"No cached version found for {model_id} with {get_export_kwargs_from_env()}." f"You can start a discussion to request it on {hub_cache_url}" f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}" ) @@ -121,8 +120,10 @@ def fetch_model( # Prefetch weights, tokenizer and generation config so that they are in cache log_cache_size() start = time.time() - snapshot_download(model_id, revision=revision, ignore_patterns="*.bin") + snapshot_path = snapshot_download( + model_id, revision=revision, ignore_patterns="*.bin" + ) end = time.time() logger.info(f"Model weights fetched in {end - start:.2f} s.") log_cache_size() - return model_id + return snapshot_path diff --git a/backends/neuron/tgi_env.py b/backends/neuron/server/text_generation_server/tgi_env.py old mode 100755 new mode 100644 similarity index 63% rename from backends/neuron/tgi_env.py rename to backends/neuron/server/text_generation_server/tgi_env.py index a7042130..ee97f180 --- a/backends/neuron/tgi_env.py +++ b/backends/neuron/server/text_generation_server/tgi_env.py @@ -6,12 +6,11 @@ import os import sys from typing import Any, Dict, List, Optional -from huggingface_hub import constants -from transformers import AutoConfig - from optimum.neuron.modeling_decoder import get_available_cores -from optimum.neuron.utils import get_hub_cached_entries +from optimum.neuron.cache import get_hub_cached_entries +from optimum.neuron.configuration_utils import NeuronConfig from optimum.neuron.utils.version_utils import get_neuronxcc_version +from optimum.neuron.utils import map_torch_dtype logger = logging.getLogger(__name__) @@ -24,15 +23,9 @@ tgi_router_env_vars = [ ] tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"] -env_config_peering = [ - ("MAX_BATCH_SIZE", "batch_size"), - ("MAX_TOTAL_TOKENS", "sequence_length"), - ("HF_AUTO_CAST_TYPE", "auto_cast_type"), - ("HF_NUM_CORES", "num_cores"), -] # By the end of this script all env var should be specified properly -env_vars = tgi_server_env_vars + tgi_router_env_vars +tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars available_cores = get_available_cores() neuronxcc_version = get_neuronxcc_version() @@ -93,9 +86,17 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace: def neuron_config_to_env(neuron_config): + if 
isinstance(neuron_config, NeuronConfig): + neuron_config = neuron_config.to_dict() with open(os.environ["ENV_FILEPATH"], "w") as f: - for env_var, config_key in env_config_peering: - f.write("export {}={}\n".format(env_var, neuron_config[config_key])) + f.write("export MAX_BATCH_SIZE={}\n".format(neuron_config["batch_size"])) + f.write("export MAX_TOTAL_TOKENS={}\n".format(neuron_config["sequence_length"])) + f.write("export HF_NUM_CORES={}\n".format(neuron_config["tp_degree"])) + config_key = ( + "auto_cast_type" if "auto_cast_type" in neuron_config else "torch_dtype" + ) + auto_cast_type = neuron_config[config_key] + f.write("export HF_AUTO_CAST_TYPE={}\n".format(auto_cast_type)) max_input_tokens = os.getenv("MAX_INPUT_TOKENS") if not max_input_tokens: max_input_tokens = int(neuron_config["sequence_length"]) // 2 @@ -111,7 +112,7 @@ def neuron_config_to_env(neuron_config): def sort_neuron_configs(dictionary): - return -dictionary["num_cores"], -dictionary["batch_size"] + return -dictionary["tp_degree"], -dictionary["batch_size"] def lookup_compatible_cached_model( @@ -119,7 +120,7 @@ def lookup_compatible_cached_model( ) -> Optional[Dict[str, Any]]: # Reuse the same mechanic as the one in use to configure the tgi server part # The only difference here is that we stay as flexible as possible on the compatibility part - entries = get_hub_cached_entries(model_id, "inference") + entries = get_hub_cached_entries(model_id) logger.debug( "Found %d cached entries for model %s, revision %s", @@ -155,15 +156,15 @@ def lookup_compatible_cached_model( def check_env_and_neuron_config_compatibility( - neuron_config: Dict[str, Any], check_compiler_version: bool + neuron_config_dict: Dict[str, Any], check_compiler_version: bool ) -> bool: logger.debug( "Checking the provided neuron config %s is compatible with the local setup and provided environment", - neuron_config, + neuron_config_dict, ) # Local setup compat checks - if neuron_config["num_cores"] > available_cores: + if neuron_config_dict["tp_degree"] > available_cores: logger.debug( "Not enough neuron cores available to run the provided neuron config" ) @@ -171,33 +172,65 @@ def check_env_and_neuron_config_compatibility( if ( check_compiler_version - and neuron_config["compiler_version"] != neuronxcc_version + and neuron_config_dict["neuronxcc_version"] != neuronxcc_version ): logger.debug( "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)", neuronxcc_version, - neuron_config["compiler_version"], + neuron_config_dict["neuronxcc_version"], ) return False - for env_var, config_key in env_config_peering: - neuron_config_value = str(neuron_config[config_key]) - env_value = os.getenv(env_var, str(neuron_config_value)) + batch_size = os.getenv("MAX_BATCH_SIZE", None) + if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size): + logger.debug( + "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)", + os.getenv("MAX_BATCH_SIZE"), + neuron_config_dict["batch_size"], + ) + return False + max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None) + if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int( + max_total_tokens + ): + logger.debug( + "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)", + max_total_tokens, + neuron_config_dict["sequence_length"], + ) + return False + num_cores = os.getenv("HF_NUM_CORES", None) + if num_cores is not None and neuron_config_dict["tp_degree"] < 
int(num_cores): + logger.debug( + "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)", + num_cores, + neuron_config_dict["tp_degree"], + ) + return False + auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None) + if auto_cast_type is not None: + config_key = ( + "auto_cast_type" + if "auto_cast_type" in neuron_config_dict + else "torch_dtype" + ) + neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key])) + env_value = map_torch_dtype(auto_cast_type) if env_value != neuron_config_value: logger.debug( - "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)", - env_var, - config_key, + "The provided auto cast type and the neuron config param differ (%s != %s)", env_value, neuron_config_value, ) return False - max_input_tokens = int( os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)) ) if max_input_tokens > 0: - sequence_length = neuron_config["sequence_length"] + if hasattr(neuron_config_dict, "max_context_length"): + sequence_length = neuron_config_dict["max_context_length"] + else: + sequence_length = neuron_config_dict["sequence_length"] if max_input_tokens >= sequence_length: logger.debug( "Specified max input tokens is not compatible with config sequence length ( %s >= %s)", @@ -211,38 +244,29 @@ def check_env_and_neuron_config_compatibility( def get_env_dict() -> Dict[str, str]: d = {} - for k in env_vars: + for k in tgi_env_vars: d[k] = os.getenv(k) return d -def main(): - """ - This script determines proper default TGI env variables for the neuron precompiled models to - work properly - :return: - """ - args = parse_cmdline_and_set_env() - - for env_var in env_vars: - if not os.getenv(env_var): - break - else: - logger.info( - "All env vars %s already set, skipping, user know what they are doing", - env_vars, +def get_neuron_config_for_model( + model_name_or_path: str, revision: Optional[str] = None +) -> NeuronConfig: + try: + neuron_config = NeuronConfig.from_pretrained( + model_name_or_path, revision=revision ) - sys.exit(0) - - cache_dir = constants.HF_HUB_CACHE - - logger.info("Cache dir %s, model %s", cache_dir, args.model_id) - - config = AutoConfig.from_pretrained(args.model_id, revision=args.revision) - neuron_config = getattr(config, "neuron", None) + except Exception as e: + logger.debug( + "NeuronConfig.from_pretrained failed for model %s, revision %s: %s", + model_name_or_path, + revision, + e, + ) + neuron_config = None if neuron_config is not None: compatible = check_env_and_neuron_config_compatibility( - neuron_config, check_compiler_version=False + neuron_config.to_dict(), check_compiler_version=False ) if not compatible: env_dict = get_env_dict() @@ -252,17 +276,6 @@ def main(): logger.error(msg) raise Exception(msg) else: - neuron_config = lookup_compatible_cached_model(args.model_id, args.revision) + neuron_config = lookup_compatible_cached_model(model_name_or_path, revision) - if not neuron_config: - msg = ( - "No compatible neuron config found. 
Provided env {}, available cores {}, neuronxcc version {}" - ).format(get_env_dict(), available_cores, neuronxcc_version) - logger.error(msg) - raise Exception(msg) - - neuron_config_to_env(neuron_config) - - -if __name__ == "__main__": - main() + return neuron_config diff --git a/backends/neuron/tests/fixtures/model.py b/backends/neuron/tests/fixtures/model.py index 4b6a1375..ad41fd10 100644 --- a/backends/neuron/tests/fixtures/model.py +++ b/backends/neuron/tests/fixtures/model.py @@ -4,14 +4,12 @@ import subprocess import sys from tempfile import TemporaryDirectory -import huggingface_hub +import os import pytest from transformers import AutoTokenizer -from optimum.neuron import NeuronModelForCausalLM -from optimum.neuron.utils import synchronize_hub_cache -from optimum.neuron.version import __sdk_version__ as sdk_version -from optimum.neuron.version import __version__ as version + +from optimum.neuron.cache import synchronize_hub_cache logging.basicConfig( @@ -21,30 +19,14 @@ logging.basicConfig( ) logger = logging.getLogger(__file__) + OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache" + # All model configurations below will be added to the neuron_model_config fixture MODEL_CONFIGURATIONS = { - "gpt2": { - "model_id": "gpt2", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 1024, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, "llama": { - "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 2048, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, - "mistral": { - "model_id": "optimum/mistral-1.1b-testing", + "model_id": "unsloth/Llama-3.2-1B-Instruct", "export_kwargs": { "batch_size": 4, "sequence_length": 4096, @@ -58,7 +40,7 @@ MODEL_CONFIGURATIONS = { "batch_size": 4, "sequence_length": 4096, "num_cores": 2, - "auto_cast_type": "fp16", + "auto_cast_type": "bf16", }, }, "granite": { @@ -73,12 +55,6 @@ MODEL_CONFIGURATIONS = { } -def get_hub_neuron_model_id(config_name: str): - return ( - f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}" - ) - - def export_model(model_id, export_kwargs, neuron_model_path): export_command = [ "optimum-cli", @@ -104,57 +80,35 @@ def export_model(model_id, export_kwargs, neuron_model_path): def neuron_model_config(request): """Expose a pre-trained neuron model - The fixture first makes sure the following model artifacts are present on the hub: - - exported neuron model under optimum-internal-testing/neuron-testing--, - - cached artifacts under optimum-internal-testing/neuron-testing-cache. - If not, it will export the model and push it to the hub. - - It then fetches the model locally and return a dictionary containing: + The fixture exports a model locally and returns a dictionary containing: - a configuration name, - the original model id, - the export parameters, - - the neuron model id, - the neuron model local path. For each exposed model, the local directory is maintained for the duration of the test session and cleaned up afterwards. - The hub model artifacts are never cleaned up and persist accross sessions. - They must be cleaned up manually when the optimum-neuron version changes. 
""" config_name = request.param model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param]) model_id = model_config["model_id"] export_kwargs = model_config["export_kwargs"] - neuron_model_id = get_hub_neuron_model_id(config_name) with TemporaryDirectory() as neuron_model_path: - hub = huggingface_hub.HfApi() - if hub.repo_exists(neuron_model_id): - logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub") - hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path) - else: - export_model(model_id, export_kwargs, neuron_model_path) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.save_pretrained(neuron_model_path) - del tokenizer - # Create the test model on the hub - hub.create_repo(neuron_model_id, private=True) - hub.upload_folder( - folder_path=neuron_model_path, - repo_id=neuron_model_id, - ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"], - ) - # Make sure it is cached - synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID) + export_model(model_id, export_kwargs, neuron_model_path) + synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(neuron_model_path) + del tokenizer # Add dynamic parameters to the model configuration model_config["neuron_model_path"] = neuron_model_path - model_config["neuron_model_id"] = neuron_model_id # Also add model configuration name to allow tests to adapt their expectations model_config["name"] = config_name # Yield instead of returning to keep a reference to the temporary directory. # It will go out of scope and be released only once all tests needing the fixture # have been completed. logger.info(f"{config_name} ready for testing ...") + os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID yield model_config logger.info(f"Done with {config_name}") diff --git a/backends/neuron/tests/server/test_cached_model.py b/backends/neuron/tests/server/test_cached_model.py new file mode 100644 index 00000000..73622578 --- /dev/null +++ b/backends/neuron/tests/server/test_cached_model.py @@ -0,0 +1,42 @@ +import os +import pytest + +from text_generation_server.generator import NeuronGenerator +from text_generation_server.model import fetch_model, is_cached + + +@pytest.fixture(scope="module") +def cached_model_id(neuron_model_config) -> str: + """ + Fixture to provide a cached model ID for testing. + This assumes the model is already cached in the local environment. 
+ """ + export_kwargs = neuron_model_config["export_kwargs"] + os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"]) + os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"]) + os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"] + os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"]) + yield neuron_model_config["model_id"] + os.environ.pop("MAX_BATCH_SIZE", None) + os.environ.pop("MAX_TOTAL_TOKENS", None) + os.environ.pop("HF_AUTO_CAST_TYPE", None) + os.environ.pop("HF_NUM_CORES", None) + + +def test_model_is_cached(cached_model_id): + assert is_cached(cached_model_id), f"Model {cached_model_id} is not cached" + + +def test_fetch_cached_model(cached_model_id: str): + model_path = fetch_model(cached_model_id) + assert os.path.exists( + model_path + ), f"Model {cached_model_id} was not fetched successfully" + assert os.path.isdir(model_path), f"Model {cached_model_id} is not a directory" + + +def test_generator_from_cached_model(cached_model_id: str): + generator = NeuronGenerator.from_pretrained(model_id=cached_model_id) + assert generator is not None, "Generator could not be created from cached model" + assert generator.model is not None, "Generator model is not initialized" + assert generator.tokenizer is not None, "Generator tokenizer is not initialized" diff --git a/backends/neuron/tests/server/test_continuous_batching.py b/backends/neuron/tests/server/test_continuous_batching.py index 48bb70cc..3d9ab509 100644 --- a/backends/neuron/tests/server/test_continuous_batching.py +++ b/backends/neuron/tests/server/test_continuous_batching.py @@ -9,13 +9,13 @@ def test_continuous_batching_two_requests(neuron_model_config): """ neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) - assert generator.model.batch_size > 1 + assert generator.model.neuron_config.batch_size > 1 input_text = "Once upon a time" max_new_tokens = 20 # Prefill a single request, remembering the generated token tokens = {0: [], 1: []} request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens) - max_length = generator.model.max_length + max_length = generator.model.neuron_config.sequence_length batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length) generations, next_batch = generator.prefill(batch) assert next_batch.size == 1 diff --git a/backends/neuron/tests/server/test_decode.py b/backends/neuron/tests/server/test_decode.py index 9db5e3ab..b864e3ec 100644 --- a/backends/neuron/tests/server/test_decode.py +++ b/backends/neuron/tests/server/test_decode.py @@ -23,7 +23,7 @@ def _test_decode(config_name, generator, do_sample): request = create_request( id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample ) - max_length = generator.model.max_length + max_length = generator.model.neuron_config.sequence_length batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length) generations, next_batch = generator.prefill(batch) # We already generated one token: call decode max_new_tokens - 1 times @@ -40,19 +40,15 @@ def _test_decode(config_name, generator, do_sample): assert output.finish_reason == 0 if do_sample: expected_text = { - "gpt2": " The sun was set", - "llama": "George Orwell, 1984", - "mistral": "The sky was", - "qwen2": " A young woman with", + "llama": " I sat alone in the café", + "qwen2": " The air was so still", "granite": "1984, George Orwell", }[config_name] assert expected_text in output.text else: print(output.text) expected_text = 
{ - "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going', - "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story", - "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.", + "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility", "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a", "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198", }[config_name] diff --git a/backends/neuron/tests/server/test_prefill.py b/backends/neuron/tests/server/test_prefill.py index c0155b1a..c9ecd1c8 100644 --- a/backends/neuron/tests/server/test_prefill.py +++ b/backends/neuron/tests/server/test_prefill.py @@ -9,7 +9,7 @@ def test_prefill(neuron_model_config): neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) max_batch_size = 4 - assert generator.model.batch_size >= max_batch_size + assert generator.model.neuron_config.batch_size >= max_batch_size for num_requests in [1, max_batch_size]: for do_sample in [True, False]: mode = "sample" if do_sample else "greedy" @@ -34,7 +34,7 @@ def _test_prefill(config_name, generator, batch_size, do_sample): ) ) # Let's be pessimistic when estimating max_tokens - max_length = generator.model.max_length + max_length = generator.max_prefill_length() batch = Batch( id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length ) @@ -46,17 +46,13 @@ def _test_prefill(config_name, generator, batch_size, do_sample): assert len(generations) == batch_size if do_sample: expectations = { - "gpt2": [383, " The"], - "llama": [10058, " George"], - "mistral": [450, " The"], - "qwen2": [362, " A"], + "llama": [358, " I"], + "qwen2": [576, " The"], "granite": [308, " ("], }[config_name] else: expectations = { - "gpt2": [198, "\n"], - "llama": [10058, " George"], - "mistral": [13, "\n"], + "llama": [578, " The"], "qwen2": [358, " I"], "granite": [203, "\n"], }[config_name] @@ -70,7 +66,7 @@ def test_prefill_truncate(neuron_model_config): config_name = neuron_model_config["name"] neuron_model_path = neuron_model_config["neuron_model_path"] generator = NeuronGenerator.from_pretrained(neuron_model_path) - batch_size = generator.model.batch_size + batch_size = generator.model.neuron_config.batch_size # We apply truncation to all requests but the first one truncate = [ None, @@ -83,7 +79,7 @@ def test_prefill_truncate(neuron_model_config): requests = [] for i in range(batch_size): requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i])) - max_length = generator.model.max_length + max_length = generator.max_prefill_length() batch = Batch( id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length ) @@ -91,12 +87,12 @@ def test_prefill_truncate(neuron_model_config): # Even if the input text is identical for all requests, the first generated token might # be different because of the truncation expectations = { - "gpt2": [" He", " He", "\n", " He"], - "llama": [" —", " The", " He", " He"], - "mistral": [" He", "\n", " He", " He"], + "llama": [" He", "iens", "\x08", " He"], "qwen2": [" He", " The", " He", " He"], "granite": ["\n", "\n", " I", " He"], }[config_name] for i, g in enumerate(generations): tokens = g.tokens - assert tokens.texts[0] == expectations[i] + assert ( + tokens.texts[0] == expectations[i] + ), f"Request {i} expected 
[{expectations[i]}], got [{tokens.texts[0]}]" diff --git a/backends/neuron/tests/test_entry_point.py b/backends/neuron/tests/test_entry_point.py new file mode 100644 index 00000000..d4ddf338 --- /dev/null +++ b/backends/neuron/tests/test_entry_point.py @@ -0,0 +1,63 @@ +import os +import pytest +from tempfile import TemporaryDirectory + +from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig +from optimum.neuron.utils import map_torch_dtype + +from text_generation_server.tgi_env import ( + get_neuron_config_for_model, + lookup_compatible_cached_model, + neuron_config_to_env, +) + + +def test_get_neuron_config_for_model(neuron_model_config): + neuron_model_path = neuron_model_config["neuron_model_path"] + export_kwargs = neuron_model_config["export_kwargs"] + os.environ["MAX_BATCH_SIZE"] = str(export_kwargs["batch_size"]) + os.environ["MAX_TOTAL_TOKENS"] = str(export_kwargs["sequence_length"]) + os.environ["HF_AUTO_CAST_TYPE"] = export_kwargs["auto_cast_type"] + os.environ["HF_NUM_CORES"] = str(export_kwargs["num_cores"]) + neuron_config = get_neuron_config_for_model(neuron_model_path) + assert neuron_config is not None + assert neuron_config.batch_size == export_kwargs["batch_size"] + assert neuron_config.sequence_length == export_kwargs["sequence_length"] + assert neuron_config.tp_degree == export_kwargs["num_cores"] + if isinstance(neuron_config, NxDNeuronConfig): + assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype( + export_kwargs["auto_cast_type"] + ) + else: + assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype( + export_kwargs["auto_cast_type"] + ) + + +@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"]) +def test_lookup_compatible_cached_model(model_id: str): + neuron_config = lookup_compatible_cached_model(model_id, None) + assert neuron_config is not None + + +def test_neuron_config_to_env(neuron_model_config) -> None: + neuron_model_path = neuron_model_config["neuron_model_path"] + neuron_config = get_neuron_config_for_model(neuron_model_path) + with TemporaryDirectory() as temp_dir: + os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh") + neuron_config_to_env(neuron_config) + with open(os.environ["ENV_FILEPATH"], "r") as env_file: + env_content = env_file.read() + assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content + assert ( + f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}" + in env_content + ) + assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content + if hasattr(neuron_config, "torch_dtype"): + auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split( + "." 
+ )[-1] + else: + auto_cast_type = neuron_config.auto_cast_type + assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh index b959a795..7965d1da 100755 --- a/backends/neuron/tgi-entrypoint.sh +++ b/backends/neuron/tgi-entrypoint.sh @@ -9,7 +9,7 @@ touch $ENV_FILEPATH SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -${SCRIPT_DIR}/tgi_env.py $@ +${SCRIPT_DIR}/tgi_entry_point.py $@ source $ENV_FILEPATH diff --git a/backends/neuron/tgi_entry_point.py b/backends/neuron/tgi_entry_point.py new file mode 100755 index 00000000..7e81d0e4 --- /dev/null +++ b/backends/neuron/tgi_entry_point.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import logging +import os +import sys + + +from text_generation_server.tgi_env import ( + available_cores, + get_env_dict, + get_neuron_config_for_model, + neuron_config_to_env, + neuronxcc_version, + parse_cmdline_and_set_env, + tgi_env_vars, +) + + +logger = logging.getLogger(__name__) + + +def main(): + """ + This script determines proper default TGI env variables for the neuron precompiled models to + work properly + :return: + """ + args = parse_cmdline_and_set_env() + + for env_var in tgi_env_vars: + if not os.getenv(env_var): + break + else: + logger.info( + "All env vars %s already set, skipping, user know what they are doing", + tgi_env_vars, + ) + sys.exit(0) + + neuron_config = get_neuron_config_for_model(args.model_id, args.revision) + + if not neuron_config: + msg = ( + "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}" + ).format(get_env_dict(), available_cores, neuronxcc_version) + logger.error(msg) + raise Exception(msg) + + neuron_config_to_env(neuron_config) + + +if __name__ == "__main__": + main() diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 1628a00b..c8b29204 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator { tokens: u32, _prefill_tokens: Option>>, ) -> Option { + let mut tokens = tokens; + if self.is_hpu_device { + // need 1 slot for ping-pong optimization + tokens += 1; + } // Apply window size let (required_blocks, repeats) = { let (tokens, repeats) = match self.window_size { @@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator { let required_blocks = tokens.div_ceil(self.block_size); (required_blocks, repeats) }; - - let mut tokens = tokens as usize; + let tokens = tokens as usize; if required_blocks > self.free_blocks.len() as u32 { None } else { @@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator { .split_off(self.free_blocks.len() - required_blocks as usize); if self.is_hpu_device { blocks.sort(); - // need 1 slot for ping-pong optimization - tokens += 1; } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index d3bf4b9c..8cfee3a5 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -8,6 +8,7 @@ use std::cmp::max; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; +use text_generation_router::usage_stats::Env; use text_generation_router::validation::{ Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, @@ -185,6 +186,9 @@ struct State { /// Paged Attention Block 
Allocation block_allocator: Option, + + /// indicate if it's hpu device, the hpu device needs padding to generate first token. + is_hpu_device: bool, } impl State { @@ -214,6 +218,7 @@ impl State { speculate, support_chunking, block_allocator, + is_hpu_device: Env::new().is_hpu_device(), } } @@ -368,6 +373,21 @@ impl State { } } + if self.is_hpu_device { + //HPU needs to pad for the prefill + max_input_length = max_input_length.max(entry.request.input_length); + let actual_prefill_tokens_for_hpu = + (batch.len() + 1) as u32 * max_input_length; + + if actual_prefill_tokens_for_hpu > prefill_token_budget { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + } + prefill_tokens += postfix_len; Some(block_allocation) diff --git a/docs/openapi.json b/docs/openapi.json index 5486413e..ff63c3da 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "3.3.0-dev0" + "version": "3.3.2-dev0" }, "paths": { "/": { diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx index 33686966..49c6739d 100644 --- a/docs/source/backends/gaudi.mdx +++ b/docs/source/backends/gaudi.mdx @@ -20,7 +20,7 @@ hf_token=YOUR_HF_ACCESS_TOKEN docker run --runtime=habana --cap-add=sys_nice --ipc=host \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model ``` @@ -52,7 +52,7 @@ hf_token=YOUR_ACCESS_TOKEN docker run --runtime=habana --cap-add=sys_nice --ipc=host \ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model ``` @@ -115,7 +115,7 @@ docker run -p 8080:80 \ -e BATCH_BUCKET_SIZE=256 \ -e PREFILL_BATCH_BUCKET_SIZE=4 \ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model \ --sharded true --num-shard 8 \ --max-input-tokens 1024 --max-total-tokens 2048 \ @@ -141,7 +141,7 @@ docker run -p 8080:80 \ -v $volume:/data \ -e PREFILL_BATCH_BUCKET_SIZE=1 \ -e BATCH_BUCKET_SIZE=1 \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model \ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \ --max-total-tokens 8192 --max-batch-size 4 @@ -208,7 +208,7 @@ docker run --runtime=habana --ipc=host --cap-add=sys_nice \ -e PROF_PATH=/tmp/hpu_profile \ -e PROF_RANKS=0 \ -e PROF_RECORD_SHAPES=True \ - ghcr.io/huggingface/text-generation-inference:3.3.0-gaudi \ + ghcr.io/huggingface/text-generation-inference:3.3.2-gaudi \ --model-id $model ``` diff --git a/docs/source/backends/neuron.md b/docs/source/backends/neuron.md index 5c4829bc..10c8a4fd 100644 --- a/docs/source/backends/neuron.md +++ b/docs/source/backends/neuron.md @@ -31,7 +31,7 @@ deployment instructions in the model card: The service is launched simply by running the text-generation-inference container with two sets of parameters: ``` -docker run ghcr.io/huggingface/text-generation-inference:3.3.0-neuron +docker run ghcr.io/huggingface/text-generation-inference:3.3.2-neuron ``` - system parameters 
are used to map ports, volumes and devices between the host and the service, diff --git a/docs/source/basic_tutorials/gated_model_access.md b/docs/source/basic_tutorials/gated_model_access.md index 35be7bab..50c71ab5 100644 --- a/docs/source/basic_tutorials/gated_model_access.md +++ b/docs/source/basic_tutorials/gated_model_access.md @@ -19,6 +19,6 @@ docker run --gpus all \ --shm-size 1g \ -e HF_TOKEN=$token \ -p 8080:80 \ - -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 \ + -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 \ --model-id $model ``` diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index 73c77d4b..a666a48a 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes ``` 4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. @@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize bitsandbytes-nf4 +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize bitsandbytes-nf4 ``` You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). @@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$ TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇 ```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.0 --model-id $model --quantize gptq +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.2 --model-id $model --quantize gptq ``` Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI. 
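A toy numeric illustration of the layer-wise objective quoted in the GPTQ section above. The naive round-to-grid step is a crude stand-in, not the actual GPTQ algorithm; it only shows the reconstruction error being minimized.

```python
# Toy illustration of the layer-wise objective ||W X - W_hat X||^2_2.
# The rounding below is NOT GPTQ, just a coarse quantizer for demonstration.
import torch

torch.manual_seed(0)
W = torch.randn(4, 8)        # original layer weights
X = torch.randn(8, 16)       # calibration inputs
scale = W.abs().max() / 7.0  # crude symmetric 4-bit-style grid
W_hat = torch.clamp(torch.round(W / scale), -8, 7) * scale
error = torch.linalg.norm(W @ X - W_hat @ X) ** 2
print(f"layer-wise reconstruction error: {error.item():.4f}")
```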
diff --git a/docs/source/installation_amd.md b/docs/source/installation_amd.md index 38e67aac..19fbe8ba 100644 --- a/docs/source/installation_amd.md +++ b/docs/source/installation_amd.md @@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --device=/dev/kfd --device=/dev/dri --group-add video \ --ipc=host --shm-size 256g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0-rocm \ + ghcr.io/huggingface/text-generation-inference:3.3.2-rocm \ --model-id $model ``` diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md index e29285c3..c1a2e867 100644 --- a/docs/source/installation_intel.md +++ b/docs/source/installation_intel.md @@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0-intel-xpu \ + ghcr.io/huggingface/text-generation-inference:3.3.2-intel-xpu \ --model-id $model --cuda-graphs 0 ``` @@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading docker run --rm --privileged --cap-add=sys_nice \ --device=/dev/dri \ --ipc=host --shm-size 1g --net host -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0-intel-cpu \ + ghcr.io/huggingface/text-generation-inference:3.3.2-intel-cpu \ --model-id $model --cuda-graphs 0 ``` diff --git a/docs/source/installation_nvidia.md b/docs/source/installation_nvidia.md index 56619bce..3aede5a9 100644 --- a/docs/source/installation_nvidia.md +++ b/docs/source/installation_nvidia.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 \ + ghcr.io/huggingface/text-generation-inference:3.3.2 \ --model-id $model ``` diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 6a2d73c1..f1d2c92a 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ - ghcr.io/huggingface/text-generation-inference:3.3.0 \ + ghcr.io/huggingface/text-generation-inference:3.3.2 \ --model-id $model ``` @@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \ To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. 
```bash -docker run ghcr.io/huggingface/text-generation-inference:3.3.0 --help +docker run ghcr.io/huggingface/text-generation-inference:3.3.2 --help ``` diff --git a/docs/source/reference/api_reference.md b/docs/source/reference/api_reference.md index 0fc8714d..5830f7b9 100644 --- a/docs/source/reference/api_reference.md +++ b/docs/source/reference/api_reference.md @@ -163,7 +163,7 @@ hub = { # create Hugging Face Model Class huggingface_model = HuggingFaceModel( - image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.0"), + image_uri=get_huggingface_llm_image_uri("huggingface",version="3.3.2"), env=hub, role=role, ) diff --git a/flake.lock b/flake.lock index 6e187510..e57990c8 100644 --- a/flake.lock +++ b/flake.lock @@ -102,7 +102,7 @@ "flake-parts": "flake-parts_3", "nix-test-runner": "nix-test-runner_3", "nixpkgs": [ - "tgi-nix", + "hf-nix", "nixpkgs" ], "pre-commit-hooks": "pre-commit-hooks_3" @@ -579,6 +579,26 @@ "type": "github" } }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_4", + "flake-utils": "flake-utils_7", + "nixpkgs": "nixpkgs_6" + }, + "locked": { + "lastModified": 1747919133, + "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, "nix-filter": { "locked": { "lastModified": 1731533336, @@ -718,16 +738,16 @@ }, "nixpkgs_6": { "locked": { - "lastModified": 1737453259, - "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", + "lastModified": 1747820358, + "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", "owner": "danieldk", "repo": "nixpkgs", - "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", + "rev": "d3c1681180717528068082103bf323147de6ab0b", "type": "github" }, "original": { "owner": "danieldk", - "ref": "outlines-v0.1.4-tgi", + "ref": "cudatoolkit-12.9-kernel-builder", "repo": "nixpkgs", "type": "github" } @@ -836,19 +856,19 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", + "hf-nix": "hf-nix", "nix-filter": "nix-filter", "nixpkgs": [ - "tgi-nix", + "hf-nix", "nixpkgs" ], - "rust-overlay": "rust-overlay", - "tgi-nix": "tgi-nix" + "rust-overlay": "rust-overlay" } }, "rust-overlay": { "inputs": { "nixpkgs": [ - "tgi-nix", + "hf-nix", "nixpkgs" ] }, @@ -970,26 +990,6 @@ "repo": "default", "type": "github" } - }, - "tgi-nix": { - "inputs": { - "flake-compat": "flake-compat_4", - "flake-utils": "flake-utils_7", - "nixpkgs": "nixpkgs_6" - }, - "locked": { - "lastModified": 1743931123, - "narHash": "sha256-MDQrbJkweLYsMYh44Gx+c1gAZOCR1fmZF1lkavAHDto=", - "owner": "huggingface", - "repo": "text-generation-inference-nix", - "rev": "1ad3feaadfdedca90278ee7676bca15019519189", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "text-generation-inference-nix", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index c733cdd2..b5b13cad 100644 --- a/flake.nix +++ b/flake.nix @@ -2,15 +2,15 @@ inputs = { crate2nix = { url = "github:nix-community/crate2nix"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + inputs.nixpkgs.follows = "hf-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; - nixpkgs.follows = "tgi-nix/nixpkgs"; + hf-nix.url = "github:huggingface/hf-nix"; + nixpkgs.follows = "hf-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; 
rust-overlay = { url = "github:oxalica/rust-overlay"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + inputs.nixpkgs.follows = "hf-nix/nixpkgs"; }; }; outputs = @@ -21,7 +21,7 @@ nixpkgs, flake-utils, rust-overlay, - tgi-nix, + hf-nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -33,10 +33,10 @@ }; pkgs = import nixpkgs { inherit system; - inherit (tgi-nix.lib) config; + inherit (hf-nix.lib) config; overlays = [ rust-overlay.overlays.default - tgi-nix.overlays.default + hf-nix.overlays.default (import nix/overlay.nix) ]; }; diff --git a/integration-tests/fixtures/neuron/export_models.py b/integration-tests/fixtures/neuron/export_models.py index 836402ec..d4d0f01c 100644 --- a/integration-tests/fixtures/neuron/export_models.py +++ b/integration-tests/fixtures/neuron/export_models.py @@ -28,15 +28,6 @@ logger = logging.getLogger(__file__) # All model configurations below will be added to the neuron_model_config fixture MODEL_CONFIGURATIONS = { - "gpt2": { - "model_id": "gpt2", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 1024, - "num_cores": 2, - "auto_cast_type": "fp16", - }, - }, "llama": { "model_id": "unsloth/Llama-3.2-1B-Instruct", "export_kwargs": { @@ -46,15 +37,6 @@ MODEL_CONFIGURATIONS = { "auto_cast_type": "fp16", }, }, - "mistral": { - "model_id": "optimum/mistral-1.1b-testing", - "export_kwargs": { - "batch_size": 4, - "sequence_length": 4096, - "num_cores": 2, - "auto_cast_type": "bf16", - }, - }, "qwen2": { "model_id": "Qwen/Qwen2.5-0.5B", "export_kwargs": { diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json index ae67e006..0c02702e 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image.\n\nThe image is a solid, bright white color. There is nothing else visible within it. \n\nIt's essentially a blank white canvas or a completely white square. \n\nIs there anything specific you'd like me to do with this image, such as describe it further or imagine what it might represent?", + "content": "Okay, let's analyze the image.\n\nThe image is a solid, bright white color. There is nothing else visible within it. 
\n\nIt's essentially a blank white square or rectangle.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1741965894, + "created": 1747062956, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { - "completion_tokens": 74, + "completion_tokens": 42, "prompt_tokens": 277, - "total_tokens": 351 + "total_tokens": 319 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json index 786ced6c..0bb67dfb 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json @@ -1,11 +1,11 @@ { "choices": [ { - "finish_reason": "length", + "finish_reason": "stop", "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g", + "content": "Okay, let's analyze the image. \n\nThe image is a very plain, solid white square. That's it! \n\nIt's essentially a blank canvas. \n\nDo you want me to describe it in more detail, or are you interested in something else regarding this image?", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1744396706, + "created": 1747062955, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { - "completion_tokens": 100, + "completion_tokens": 62, "prompt_tokens": 277, - "total_tokens": 377 + "total_tokens": 339 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json index 571478ee..dc1309d2 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.", + "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. 
\n\nDo you want me to describe any specific element of the image in more detail?", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1744396703, + "created": 1747062952, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { - "completion_tokens": 78, + "completion_tokens": 67, "prompt_tokens": 277, - "total_tokens": 355 + "total_tokens": 344 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json index 9fb0c4c5..7f7d0ef6 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene – a cow enjoying a day at the beach!", + "content": "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene – a cow enjoying a beach day!", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1744396699, + "created": 1747216083, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { - "completion_tokens": 74, + "completion_tokens": 72, "prompt_tokens": 275, - "total_tokens": 349 + "total_tokens": 347 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json index 0ed2b1e1..35ca9cf0 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their reddish-brown color and distinctive white markings. \n\nIf you'd like, you can send me another image and I’ll do my best to identify it!", + "content": "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their beautiful reddish-brown coats and distinctive white markings. 
\n\nIf you'd like, you can send me another image, and I'll do my best to identify the animal in it!", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1744396697, + "created": 1747216080, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { - "completion_tokens": 75, + "completion_tokens": 80, "prompt_tokens": 279, - "total_tokens": 354 + "total_tokens": 359 } } diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json index bc75bce4..90d787d4 100644 --- a/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json +++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json @@ -8,126 +8,126 @@ "tokens": [ { "id": 108, - "logprob": -0.73046875, + "logprob": -0.48046875, "special": false, "text": "\n" }, { "id": 30234, - "logprob": -2.328125, + "logprob": -2.21875, "special": false, "text": "Brown" }, { "id": 108, - "logprob": -0.12060547, + "logprob": -0.119140625, "special": false, "text": "\n" }, { "id": 3726, - "logprob": -1.7734375, + "logprob": -1.703125, "special": false, "text": "Car" }, { "id": 108, - "logprob": -0.041503906, + "logprob": -0.0390625, "special": false, "text": "\n" }, { "id": 2915, - "logprob": -1.796875, + "logprob": -1.8203125, "special": false, "text": "Color" }, { "id": 108, - "logprob": -0.039794922, + "logprob": -0.035888672, "special": false, "text": "\n" }, { "id": 19178, - "logprob": -1.96875, + "logprob": -2.015625, "special": false, "text": "Cool" }, { "id": 108, - "logprob": -0.080566406, + "logprob": -0.08105469, "special": false, "text": "\n" }, { "id": 40544, - "logprob": -2.1875, + "logprob": -2.09375, "special": false, "text": "Decor" }, { "id": 108, - "logprob": -0.033935547, + "logprob": -0.038330078, "special": false, "text": "\n" }, { - "id": 13936, + "id": 108, + "logprob": -1.515625, + "special": false, + "text": "\n" + }, + { + "id": 108, + "logprob": -1.8671875, + "special": false, + "text": "\n" + }, + { + "id": 108, "logprob": -1.6328125, "special": false, - "text": "Green" + "text": "\n" }, { "id": 108, - "logprob": -0.16210938, + "logprob": -1.265625, "special": false, "text": "\n" }, - { - "id": 955, - "logprob": -2.015625, - "special": false, - "text": "..." - }, { "id": 108, - "logprob": -0.14746094, + "logprob": -1.0078125, "special": false, "text": "\n" }, - { - "id": 955, - "logprob": -0.73828125, - "special": false, - "text": "..." - }, { "id": 108, - "logprob": -0.051513672, + "logprob": -1.03125, "special": false, "text": "\n" }, { - "id": 955, - "logprob": -0.34765625, + "id": 235336, + "logprob": -1.2109375, "special": false, - "text": "..." + "text": "?" }, { "id": 108, - "logprob": -0.020141602, + "logprob": -0.29101562, "special": false, "text": "\n" }, { - "id": 955, - "logprob": -0.11767578, + "id": 235336, + "logprob": -0.08935547, "special": false, - "text": "..." + "text": "?" } ], "top_tokens": null }, - "generated_text": "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..." + "generated_text": "\nBrown\nCar\nColor\nCool\nDecor\n\n\n\n\n\n\n?\n?" 
} diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json index a4727707..46e67856 100644 --- a/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json @@ -2,84 +2,90 @@ "details": { "best_of_sequences": null, "finish_reason": "eos_token", - "generated_tokens": 12, + "generated_tokens": 13, "prefill": [], "seed": null, "tokens": [ { "id": 450, - "logprob": -0.26342773, + "logprob": -0.2602539, "special": false, "text": " The" }, { "id": 21282, - "logprob": -0.01838684, + "logprob": -0.018463135, "special": false, "text": " cow" }, { "id": 322, - "logprob": -0.18041992, + "logprob": -0.1829834, "special": false, "text": " and" }, { "id": 521, - "logprob": -0.62841797, + "logprob": -0.62109375, "special": false, "text": " ch" }, { "id": 21475, - "logprob": -0.0037956238, + "logprob": -0.0037403107, "special": false, "text": "icken" }, { "id": 526, - "logprob": -0.018737793, + "logprob": -0.018920898, "special": false, "text": " are" }, + { + "id": 13407, + "logprob": -1.0732422, + "special": false, + "text": " standing" + }, { "id": 373, - "logprob": -1.0820312, + "logprob": -0.5292969, "special": false, "text": " on" }, { "id": 263, - "logprob": -0.5083008, + "logprob": -0.47070312, "special": false, "text": " a" }, { "id": 25695, - "logprob": -0.07128906, + "logprob": -0.25708008, "special": false, "text": " beach" }, { "id": 29889, - "logprob": -0.12573242, + "logprob": -0.17578125, "special": false, "text": "." }, { "id": 32002, - "logprob": -0.0029792786, + "logprob": -0.0023422241, "special": true, "text": "" }, { "id": 2, - "logprob": -0.00024962425, + "logprob": -0.00030851364, "special": true, "text": "" } ], "top_tokens": null }, - "generated_text": " The cow and chicken are on a beach." + "generated_text": " The cow and chicken are standing on a beach." 
} diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json index 23fb8dda..c93f8a67 100644 --- a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json @@ -14,7 +14,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 35, "prompt_tokens": 32, diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json index e344a226..326d6702 100644 --- a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json @@ -14,7 +14,7 @@ "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 44, "prompt_tokens": 37, diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json index ef88926c..1c18cae9 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -64,7 +64,7 @@ }, { "id": 329, - "logprob": -2.171875, + "logprob": -2.296875, "special": false, "text": " A" }, @@ -81,19 +81,19 @@ "text": " of" }, { - "id": 1027, - "logprob": -1.5546875, + "id": 253, + "logprob": -0.86328125, "special": false, - "text": " different" + "text": " the" }, { "id": 3295, - "logprob": -0.97265625, + "logprob": -0.55078125, "special": false, "text": " color" } ], "top_tokens": null }, - "generated_text": "blue, red, yellow, \nand blue colors. A number of different color" + "generated_text": "blue, red, yellow, \nand blue colors. 
A number of the color" } diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json index 9d4c98ef..682e10d4 100644 --- a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json @@ -14,11 +14,11 @@ "usage": null } ], - "created": 1746054921, + "created": 1747230173, "id": "", - "model": "meta-llama/Llama-3.2-11B-Vision-Instruct", + "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, @@ -40,11 +40,11 @@ "usage": null } ], - "created": 1746054921, + "created": 1747230173, "id": "", - "model": "meta-llama/Llama-3.2-11B-Vision-Instruct", + "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json index d1049ead..c3c5e76b 100644 --- a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "A chicken stands on a pile of money, looking", + "content": "A chicken sits on a pile of money, looking", "name": null, "role": "assistant", "tool_calls": null @@ -13,11 +13,11 @@ "usage": null } ], - "created": 1746054919, + "created": 1747230171, "id": "", - "model": "meta-llama/Llama-3.2-11B-Vision-Instruct", + "model": "unsloth/Llama-3.2-11B-Vision-Instruct", "object": "chat.completion", - "system_fingerprint": "3.3.0-dev0-native", + "system_fingerprint": "3.3.2-dev0-native", "usage": { "completion_tokens": 10, "prompt_tokens": 45, diff --git a/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json index 17a69d0d..5a30fd80 100644 --- a/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json +++ b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json @@ -8,49 +8,49 @@ "tokens": [ { "id": 330, - "logprob": -0.118652344, + "logprob": -0.107421875, "special": false, "text": " A" }, { "id": 11426, - "logprob": -0.28320312, + "logprob": -0.30078125, "special": false, "text": " bee" }, { "id": 335, - "logprob": -0.95703125, + "logprob": -0.9609375, "special": false, "text": " on" }, { "id": 253, - "logprob": -0.06982422, + "logprob": -0.0703125, "special": false, "text": " a" }, { "id": 11986, - "logprob": -0.49414062, + "logprob": -0.5, "special": false, "text": " pink" }, { "id": 8525, - "logprob": -0.07763672, + "logprob": -0.09716797, "special": false, "text": " flower" }, { "id": 30, - "logprob": -1.0703125, + "logprob": -1.078125, "special": false, "text": "." 
}, { "id": 49154, - "logprob": -0.092285156, + "logprob": -0.110839844, "special": true, "text": "" } diff --git a/integration-tests/models/test_flash_gemma3.py b/integration-tests/models/test_flash_gemma3.py index 5064f34d..cd9d98ea 100644 --- a/integration-tests/models/test_flash_gemma3.py +++ b/integration-tests/models/test_flash_gemma3.py @@ -53,9 +53,9 @@ async def test_flash_gemma3_image_cow_dog(flash_gemma3, response_snapshot): assert ( response.choices[0].message.content - == "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their reddish-brown color and distinctive white markings. \n\nIf you'd like, you can send me another image and I’ll do my best to identify it!" + == "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their beautiful reddish-brown coats and distinctive white markings. \n\nIf you'd like, you can send me another image, and I'll do my best to identify the animal in it!" ) - assert response.usage["completion_tokens"] == 75 + assert response.usage["completion_tokens"] == 80 assert response == response_snapshot @@ -76,9 +76,9 @@ async def test_flash_gemma3_image_cow(flash_gemma3, response_snapshot): ) assert ( response.choices[0].message.content - == "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene – a cow enjoying a day at the beach!" + == "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene – a cow enjoying a beach day!" ) - assert response.usage["completion_tokens"] == 74 + assert response.usage["completion_tokens"] == 72 assert response == response_snapshot diff --git a/integration-tests/models/test_flash_pali_gemma2.py b/integration-tests/models/test_flash_pali_gemma2.py index 23705385..bef9628d 100644 --- a/integration-tests/models/test_flash_pali_gemma2.py +++ b/integration-tests/models/test_flash_pali_gemma2.py @@ -22,8 +22,7 @@ async def test_flash_pali_gemma_image(flash_pali_gemma, response_snapshot): max_new_tokens=20, ) assert ( - response.generated_text - == "\nBrown\nCar\nColor\nCool\nDecor\nGreen\n...\n...\n...\n..." + response.generated_text == "\nBrown\nCar\nColor\nCool\nDecor\n\n\n\n\n\n\n?\n?" ) assert response == response_snapshot diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index e5d08bb7..848a6674 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -39,7 +39,7 @@ async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach max_new_tokens=20, ) assert ( - response.generated_text == " The cow and chicken are on a beach." + response.generated_text == " The cow and chicken are standing on a beach." 
), f"{repr(response.generated_text)}" assert response == response_snapshot diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index baa19643..ec0f90a1 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -47,7 +47,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): assert response.details.generated_tokens == 10 assert ( response.generated_text - == "blue, red, yellow, \nand blue colors. A number of different color" + == "blue, red, yellow, \nand blue colors. A number of the color" ) assert response == response_snapshot diff --git a/integration-tests/models/test_mllama.py b/integration-tests/models/test_mllama.py index 95ec1d99..fd58e4bf 100644 --- a/integration-tests/models/test_mllama.py +++ b/integration-tests/models/test_mllama.py @@ -5,7 +5,7 @@ import asyncio @pytest.fixture(scope="module") def mllama_handle(launcher): with launcher( - "meta-llama/Llama-3.2-11B-Vision-Instruct", + "unsloth/Llama-3.2-11B-Vision-Instruct", num_shard=2, ) as handle: yield handle @@ -48,7 +48,7 @@ async def test_mllama_simpl(mllama, response_snapshot): } assert ( response.choices[0].message.content - == "A chicken stands on a pile of money, looking" + == "A chicken sits on a pile of money, looking" ) assert response == response_snapshot diff --git a/integration-tests/neuron/test_generate.py b/integration-tests/neuron/test_generate.py index f0804356..9108ce0e 100644 --- a/integration-tests/neuron/test_generate.py +++ b/integration-tests/neuron/test_generate.py @@ -20,9 +20,7 @@ async def test_model_single_request(tgi_service): ) assert response.details.generated_tokens == 17 greedy_expectations = { - "gpt2": "\n\nDeep learning is a new field of research that has been around for a while", - "llama": " and How Does it Work?\nDeep learning is a subset of machine learning that uses artificial", - "mistral": "\nWhat is Deep Learning?\nDeep Learning is a type of machine learning that", + "llama": " and how does it work?\nDeep learning is a subset of machine learning that uses artificial", "qwen2": " - Part 1\n\nDeep Learning is a subset of Machine Learning that is based on", "granite": "\n\nDeep Learning is a subset of Machine Learning, which is a branch of Art", } @@ -79,9 +77,7 @@ async def test_model_multiple_requests(tgi_service, neuron_generate_load): assert len(responses) == 4 expectations = { - "gpt2": "Deep learning is a new field of research that has been around for a while", "llama": "Deep learning is a subset of machine learning that uses artificial", - "mistral": "Deep Learning is a type of machine learning that", "qwen2": "Deep Learning is a subset of Machine Learning that is based on", "granite": "Deep Learning is a subset of Machine Learning, which is a branch of Art", } diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index d9056e41..cd4ee290 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -27,10 +27,6 @@ impl Env { docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), } } - - pub fn should_start_a_single_hpu_shard(&self) -> bool { - self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged") - } } impl fmt::Display for Env { diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a82ad12f..c727623c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1263,7 +1263,23 @@ fn num_cuda_devices() -> Option { let devices = match env::var("CUDA_VISIBLE_DEVICES") { Ok(devices) => devices, Err(_) => 
match env::var("NVIDIA_VISIBLE_DEVICES") { - Ok(devices) => devices, + Ok(devices) => { + if devices.trim() == "all" { + // Count the number of all GPUs via nvidia-smi + let output = Command::new("nvidia-smi") + .args(["--query-gpu=uuid", "--format=csv,noheader"]) + .output() + .ok()?; + + String::from_utf8_lossy(&output.stdout) + .lines() + .filter(|line| !line.trim().is_empty()) + .count() + .to_string() + } else { + devices + } + } Err(_) => env::var("ZE_AFFINITY_MASK").ok()?, }, }; @@ -1574,11 +1590,6 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() { - tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); - break; - } - let model_id = args.model_id.clone(); let revision = args.revision.clone(); let uds_path = args.shard_uds_path.clone(); @@ -1654,10 +1665,6 @@ fn spawn_shards( if shard_ready == num_shard { break; } - if env_runtime::Env::new().should_start_a_single_hpu_shard() { - tracing::info!("HPU detected, shard is ready"); - break; - } } Err(TryRecvError::Empty) => { sleep(Duration::from_millis(100)); diff --git a/nix/client.nix b/nix/client.nix index 351fd08a..be8e2fc7 100644 --- a/nix/client.nix +++ b/nix/client.nix @@ -1,6 +1,7 @@ { buildPythonPackage, poetry-core, + aiohttp, huggingface-hub, pydantic, }: @@ -15,6 +16,7 @@ buildPythonPackage { build-system = [ poetry-core ]; dependencies = [ + aiohttp huggingface-hub pydantic ]; diff --git a/nix/overlay.nix b/nix/overlay.nix index 069fdd80..0eb07c2a 100644 --- a/nix/overlay.nix +++ b/nix/overlay.nix @@ -13,26 +13,26 @@ final: prev: { ( python-self: python-super: with python-self; { # Python package override example: - transformers = python-super.transformers.overrideAttrs ( - _: _: { - src = final.fetchFromGitHub { - owner = "huggingface"; - repo = "transformers"; - rev = "v4.51.0"; - hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8="; - }; - } - ); - huggingface-hub = python-super.huggingface-hub.overrideAttrs ( - _: _: { - src = final.fetchFromGitHub { - owner = "huggingface"; - repo = "huggingface_hub"; - rev = "v0.30.0"; - hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o="; - }; - } - ); + #transformers = python-super.transformers.overrideAttrs ( + # _: _: { + # src = final.fetchFromGitHub { + # owner = "huggingface"; + # repo = "transformers"; + # rev = "v4.51.0"; + # hash = "sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8="; + # }; + # } + #); + #huggingface-hub = python-super.huggingface-hub.overrideAttrs ( + # _: _: { + # src = final.fetchFromGitHub { + # owner = "huggingface"; + # repo = "huggingface_hub"; + # rev = "v0.30.0"; + # hash = "sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o="; + # }; + # } + #); } ) ]; diff --git a/nix/server.nix b/nix/server.nix index e6493531..a45f39cc 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -31,7 +31,7 @@ peft, pillow, prometheus-client, - punica-kernels, + punica-sgmv, py-cpuinfo, pydantic, quantization, @@ -107,7 +107,7 @@ buildPythonPackage { peft pillow prometheus-client - punica-kernels + punica-sgmv py-cpuinfo pydantic quantization diff --git a/server/Makefile b/server/Makefile index cf6c7370..a95a4ae5 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,7 +3,6 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-selective-scan -include Makefile-lorax-punica include Makefile-exllamav2 include Makefile-flashinfer 
@@ -38,7 +37,7 @@ install: install-cuda echo "Installed server" install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention - uv sync --inexact --extra attention --extra bnb --extra marlin --extra moe --active + uv sync --inexact --extra attention --extra bnb --active uv pip install nvidia-nccl-cu12==2.22.3 kernels download . @@ -46,6 +45,6 @@ install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm export-requirements: uv pip compile pyproject.toml --extra gen -o requirements_gen.txt --python-version 3.11 - uv pip compile pyproject.toml --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11 + uv pip compile pyproject.toml --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11 uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_intel.txt --python-version 3.11 uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_rocm.txt --python-version 3.11 diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica deleted file mode 100644 index 72f06f76..00000000 --- a/server/Makefile-lorax-punica +++ /dev/null @@ -1,12 +0,0 @@ -lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc - -build-lorax-punica: - if [ ! -d 'lorax-punica' ]; then \ - git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ - fi - cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) - cd lorax-punica && git submodule update --init --recursive - cd lorax-punica/server/punica_kernels && python setup.py build - -install-lorax-punica: build-lorax-punica - cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/server/kernels.lock b/server/kernels.lock index 9e11de68..a06cbff3 100644 --- a/server/kernels.lock +++ b/server/kernels.lock @@ -1,270 +1,468 @@ [ { "repo_id": "kernels-community/paged-attention", - "sha": "331b7e63a6b592799c8bc992f681bb1ee2c865a2", + "sha": "1e0a9708f0fe47009a3d292226c5492474353258", "variants": { "torch25-cxx11-cu118-x86_64-linux": { - "hash": "sha256-8e0aa39abab82f1d21b661d35e0470a24c3ebbdda38532ded805c18037a1ad1e", + "hash": "sha256-99710450ce815fdd0eeab3862ed0940c37a236c4f6cd49399e0112d66c9e40cb", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu121-x86_64-linux": { - "hash": "sha256-b0c3aef6c4c9aac627975cb1a2bfc46a70390763c8165575b89d1651d007c38a", + "hash": "sha256-bf136ffb4732e141e05738606a014fde18d3aa6d4345d6223858327c00eef2d1", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu124-x86_64-linux": { - "hash": "sha256-960fbc8998439d779adb47fb2a37cce68c7dc075d8a49893bd487be9ca2d1389", + "hash": "sha256-5ff343fc4feadf36ea38032d2a014a1cd6008fe22dea26191cd397745dbaf8ae", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu118-x86_64-linux": { - "hash": "sha256-9d6d60c411c55aa2f9d7c681c2be96f4262d56c96f592f3d4fb35ce4f4f1e18e", + "hash": "sha256-5db4fd37dcc6ec49ea71eba49415758b98fc21699155632902c76a545b36c47a", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu121-x86_64-linux": { - "hash": "sha256-98c0a305b2cc9b7be757fab923d9aa406c686dcd0460e462926f87d051ef3d19", + "hash": 
"sha256-995ff1a0cfe569639bc1644b5d6d823ea47ad0da33fe1cf398370ee70a203eb3", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu124-x86_64-linux": { - "hash": "sha256-71e586416213c96ffbdeae0d077ba97bfde5b00005f2746d4cba2320cb53bf87", + "hash": "sha256-1a00b021ea1273acb003ebd459699287ebf3d03f949befa31ae91899fa90b9e8", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu118-x86_64-linux": { - "hash": "sha256-2f559312c54d558b33a4082ffc3fcf923f51da40ced19bfc8920e998ba2b71bf", + "hash": "sha256-91e57835ae0f6e2df38c65c9e2eb47d33b609c7c117f6a86898740ad17653dba", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu124-x86_64-linux": { - "hash": "sha256-6033b41a0f8a9509887c6171f0b42d9aa738490903b3fd5ea2c52703c5fb8fc3", + "hash": "sha256-5435890298a7eca613c805c8aee08b5a4405a1a7ad38ad3bc43bba14b26683ae", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-b3dffef571f4f813b727ce3b2fcb7b43ee9d2e793b594e6ccf3a694bac87280a", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu126-x86_64-linux": { - "hash": "sha256-3139f66a53f2bf0c314b4d309893095746bdc9c3914c904fc31adfdf553ed219", + "hash": "sha256-7ce5d58943f52959cc9643477e4dc211c7592628968cc53714e307092c95a769", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu118-x86_64-linux": { - "hash": "sha256-2173d77e384d8e2881fc38603992c09e8be7bcd9da4cafdd4f2a5ce0ce22caf4", + "hash": "sha256-c74c251ba84cf6ea4c0402ed6dec7dca92f46b101f299a0abb1bcab5c83d2165", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu124-x86_64-linux": { - "hash": "sha256-7b1aaef81e01ecce83e03c50872910680ff2953f7c6ffd3ff15e8d9497ca9239", + "hash": "sha256-44661e14516679bfa1788a4919c01014e9cd2402ad6231947bf7a6ca55002ecd", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-e28ca88f80f95eede03eae610c08f83caabe579e15d110d9e070e46b6435770f", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu126-x86_64-linux": { - "hash": "sha256-818b160a88b12b8e871099e40f76aa436ee828e2e060ecc35502dbe34a6ebd3b", + "hash": "sha256-05eb63f56b6b665d0e25919a8f429c8c3b2e0e3fc55725885d0e68e9011ca283", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-ef0c14844fd8df0ce765b85497c90ce1091b4a780642d86bf206799ba9d3c94a", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-ab151aea475c6880ed15e8f9232bf8720f7f0f2b96acdac65a5bcb7e5ab727b1", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-08345dd704dcea727b9c2c109664f1602f97908fed84522edb817d95eb859f74", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-c2419e4057e26bd90360dacd30f1b51eea1fde2efed9bd4c7db034ffc2962a5a", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-a85fa6b43d438380c9d064769d8dd509ebf5206327a326082c0c249c0704ca46", "hash_type": "git_lfs_concat" } } }, { "repo_id": "kernels-community/moe", - "sha": "605a216f507b9a97b543140dee8937a4622069a8", + "sha": "e3efab933893cde20c5417ba185fa3b7cc811b24", "variants": { "torch25-cxx11-cu118-x86_64-linux": { - "hash": "sha256-855d92f02be3bfba0758161fa1266159d76c172e7c5d43d30816d22cfba76074", + "hash": "sha256-719817bc2320f52d510e4a62bceef41a0ba8c58ea0e67d844db4225add3c5783", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu121-x86_64-linux": { - "hash": "sha256-e6e780230477bbbc26fc40cc7fcff50298155998af4fc77a026c9f815ec984b1", + "hash": "sha256-1b5973b5d9376e377ff223aed71936cc25f19367c8db7fcd9aa70960c15de290", 
"hash_type": "git_lfs_concat" }, "torch25-cxx11-cu124-x86_64-linux": { - "hash": "sha256-52c1fb337033c4d1d7a279c5cb28aebbc7389976f21dc5803aeb16b2f7aeb94c", + "hash": "sha256-69e1e5603c01227c3e2cbd67c09dd39fa7c0d4ecf3f736a2eb07227f6bb8935b", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu118-x86_64-linux": { - "hash": "sha256-1fb654e8d02dda2a2382d1fb3a3ca9738d292eea674b30b80030cdcdfb6a0035", + "hash": "sha256-91626ab4046b04e1a0967cc5c8a60a248e611b413e1cace3e4bdb0fc3a68a0e4", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu121-x86_64-linux": { - "hash": "sha256-0cf235f1de85d4ce7490c79aa64220f608f886f313b676d91c331a6a2fd67bbb", + "hash": "sha256-84dd628239aa3043bc048c51f513faf55042ccc3d372002bbc231b0aa6d6689f", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu124-x86_64-linux": { - "hash": "sha256-3def11fee9bf1ea9b1579206fd5f5ecbcaad47ac478e2c3aa7b2c9c7fd5db934", + "hash": "sha256-ffb9743f69aae59fba1cfed1fc9e2e0f90a9000121c2db5880f0e055a714931a", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu118-x86_64-linux": { - "hash": "sha256-3a49ee03f675190a79c7c74a45cc403d491eceb63a943f47d52064a11ca6ef6f", + "hash": "sha256-30560d5c091a9be1914fc8bf42d86767cfb07f1b7335f1ee88797e42f31e7856", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu124-x86_64-linux": { - "hash": "sha256-dbf20cb11db7d53e11147ab13641eefaa235f9ac2fde1beaf8f56f850c11bd54", + "hash": "sha256-6e2afd532fdc9cee8f532097a80e4c2139f47df8005c43c5cdac42204d6217e1", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-93d46cc7701358cd5a4e5ae3fafde8120fdb765149b9a9224f52a802b7d48cf1", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu126-x86_64-linux": { - "hash": "sha256-8a07232ab316e8eab74747662cb7b86aac03f44ff158f275768fd59390df2525", + "hash": "sha256-e57c961ea9c1a411c5b348986e359b1e6f1102fa09cfaa82d20f96d09528098a", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu118-x86_64-linux": { - "hash": "sha256-cdd46301af997eeace5e016d8590969981b3a3f8647828d04baa5fa10c696746", + "hash": "sha256-946b982082c008220a667f44e4308c17933e0d4785cad72ececa35273275f09c", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu124-x86_64-linux": { - "hash": "sha256-c865188e9d2c17f3358f3d343fb40340232457572744bf85efd6b20af545d5f3", + "hash": "sha256-227be46b6cc468fadc237bb616d14e4747ad122bc0a2cd5bbef1a2b89a63d5bf", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-d0dc0c8f34608f7c735e804c606dff029708349e68d5b9d9df7541b2498c1e8e", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu126-x86_64-linux": { - "hash": "sha256-2a8b09f3272ea80491e78a39ff886680471d99f7ba571581809adfe918013898", + "hash": "sha256-91b3df206bd4418e42d08608fdf652d65612342efc8f67958a66d68038179567", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-4b0f4536cd8f24ef00f06e00dfa0123c03dada7de3394a6274ec5cfa3bbf31f6", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-4c8468437ac977116f46be9a6871b0887f762ba44d3aea3c3ce2eb41637fb626", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-9a0d84b8636a897e4a5abd243f48a71d7d470c2f8e28df6a6874a9d981105c0f", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-11e6c4ce82a25d17664b4100af419f974fc312ac283195129c91519dac4d5812", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-c49a6eda12752adf78690a5e985a55d3b85d6724be5d18db51cd03d5fc75cc9b", + 
"hash_type": "git_lfs_concat" + } + } + }, + { + "repo_id": "kernels-community/punica-sgmv", + "sha": "9ae1b469cb39c33df9ddd61657c6359acc423714", + "variants": { + "torch26-cxx11-cu118-x86_64-linux": { + "hash": "sha256-766062cd845bdebbe4e4391fda6f2663bebc2c110cbc2642d09c8c09ccf3f1d4", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu124-x86_64-linux": { + "hash": "sha256-c9cd76df7c84851aa566deb1c0d04ebddc1b1908a29df218344f2b3d53c4e683", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-ae444bf53be3d469d4c9c58faef7d61a92e873e6104afe5aed2b2a1397333e99", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-x86_64-linux": { + "hash": "sha256-0706cc5ccf9cedae0bb6a938acdf2d5599a7b8f8b1fe46118b6ad61c0f3432af", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu118-x86_64-linux": { + "hash": "sha256-42cf390c6ae48b18041e201d4c67b4bf820b9f9cafe49a12c505f7920bae56ae", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu124-x86_64-linux": { + "hash": "sha256-75c97c23bfe32f65830341420d093a07df051828f385cbc5357b073c635f442f", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-2ff5590ff6c298220c6e06142c971b08a686b98abb8d7dd1e6eb4539fa115cba", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-x86_64-linux": { + "hash": "sha256-70bcf04490865df6518c9d6a4c7eb2fee76b14642651f04a061c20ffa6fdb283", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-727b8f5b22e4e91b956516235f26c39013a87ac6e196a0ce5f1897c2d959e69d", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-bfddd19db7c9268a83e3cc5e281b007de80ab0fe611b3856ffd1691b400eca46", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-940c68f5d4d8a2391b1eb3c7c5a56623428862f428aa5c6c1f7e62588c0e36fb", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-781259a371b67bfbf744431c88a6ee847ab48459e73cb57264590de2728d6b3a", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-8977a33d7884bebb9fb5e3d7daf157119206f0f18a22edb2b96ec593d5c81ae1", "hash_type": "git_lfs_concat" } } }, { "repo_id": "kernels-community/quantization", - "sha": "95272c71ca71b1ddbacb0105dab54e5d5240bd5c", + "sha": "6470f9b005797e00279eb9103463dfe0f8b7da00", "variants": { "torch25-cxx11-cu118-x86_64-linux": { - "hash": "sha256-2d0a274cf0117bf7880d6040adafa1b70fe8bff3a00ef2834ed5435a6b525a49", + "hash": "sha256-f52c9b1a7cd98fb389c6d2a0b22a293cb36eb96af3a624f5aec761735861c96d", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu121-x86_64-linux": { - "hash": "sha256-116458beac63ea5eeb1e7fba7edc68d160cd8ac28f55b926d79035551aac7d5f", + "hash": "sha256-e5f0da343363a562ce52f147a9534cd54a3efa90e70671f606cc2516f02a3876", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu124-x86_64-linux": { - "hash": "sha256-cace644c6fb04470384796c18987135cb051dfb90a14e902c51a3786fc07c599", + "hash": "sha256-caad9300c155faf79c26426f10951ba75f931a05e741a5b39a24b064daabc040", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu118-x86_64-linux": { - "hash": "sha256-104c6961cd3e1a74efdf14ea2172acc6647846852fccafe3698a27a6cf37941d", + "hash": "sha256-4fc87893de14a29ba4b55f5026ea05ec5901c0b52abd5ebae681ea0b791e858c", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu121-x86_64-linux": { - "hash": "sha256-cdc95b41aa91a803f11f8cd53001895c2b69550b5af2fb278d6f124381229d0b", + "hash": 
"sha256-72c975ea63fc524a38fcee5b2dbdb566eff0a0ea546ee5756441d04908e4e896", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu124-x86_64-linux": { - "hash": "sha256-d5388469cb6074f196f20b1e1e4805bb3c967a8147b31ca2c0461aa87b50604e", + "hash": "sha256-28c5510e3b07eae2b3846b880f6111da65df024e1f24f81077d187a97c015364", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu118-x86_64-linux": { - "hash": "sha256-70c4bb3792c4c3207d4963173d8d0ef3b2bda677151aef140662dd87bfa1b69f", + "hash": "sha256-8444cf77686578a6b0f7e2fd29bf2783ba120ebf7df41573f61d2521fd0acc10", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu124-x86_64-linux": { - "hash": "sha256-bcacbb2232f49345f27e07fa821b48a7e3df643c01af37281fcafc74c471f682", + "hash": "sha256-6ea8e00625b5fe799fbe407e7de0fc08228cac26f9bbed2d70a6500026fe3bab", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-0b8b8afbdaf9aa533895cb9e884e3ad3e9a34d483f05a1bbde1b8902f9dbeb0f", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu126-x86_64-linux": { - "hash": "sha256-344d20964f7eb133e5ec6fda976fa5ee62807b739a4361f236aca5ae53beb9ac", + "hash": "sha256-e115e855d7ca4b97787f04c88e128432256c6b43d4823fb8889ab9985dc4cf36", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu118-x86_64-linux": { - "hash": "sha256-dfaec226550254fbce1a5c7e2f547e85700958a1a4087e1c873d22e6f71a5ceb", + "hash": "sha256-509f08c48a05584cc85c058607277fcbe3193e6cc61846dd2416d39e27c1d68e", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu124-x86_64-linux": { - "hash": "sha256-0abe6460d0a2202b0086e3663092595e5b93b9a9cbb85c10034180cc9bfebc6e", + "hash": "sha256-a10236bffd435296c736ae2762ab0836da2421297e46b377368a17b39d70c27b", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-ca2cb56f3eea4c399a61e21ba9b577d718b250aa60a13f42f01019ddd5cd8b0c", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu126-x86_64-linux": { - "hash": "sha256-68e156f94c3c0c9523773b62eaeced93766e0d9ee67d8191fb9570fb5af30d5b", + "hash": "sha256-8fcd62d8243a30b63a03751cc0c15d24f6e00e43eae79f7281627f24e078bf9a", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-60f5807ee3da937c57c1b6080c30632305aa4875ed5a52bf4e81968770b61b13", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-64298b1713dc1d950915dc6569a06e2f541de3ed80aa5b32084246c1fdc7a958", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-d9e219890dc28e8582ef21d6f81f2ebc361de218a86b742be63bc4714f102e5e", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-d72549f51aefcf020bc74262bbbccb78094638c5ab9adc8667873d247c1cce86", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-d31ac5f87d7c7f62c63c72946479193aed467c9417c0acead5137e0e1fa968f8", "hash_type": "git_lfs_concat" } } }, { "repo_id": "kernels-community/quantization-eetq", - "sha": "a80ce846d6270ddddeee109523ed947f594f246b", + "sha": "1aa83b1261b0c4cad890184a4d689e6330a110b5", "variants": { "torch25-cxx11-cu118-x86_64-linux": { - "hash": "sha256-e06beb00799b1e656583eb0496f09fc0bf1b26f75e9864a2fe19ebd5b62c3671", + "hash": "sha256-de257728ec38f48220d6c90b2fd960fed1f4c963e7cd6c204abfcf8607aedc20", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu121-x86_64-linux": { - "hash": "sha256-c128d3ef6558cfedf045c4a713891792708851b7f6f027de835d9083cb3b297d", + "hash": 
"sha256-9027918cf6e52591f97b2c621355e12d9adf0dfe833a763219813bfecd1ad1a3", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu124-x86_64-linux": { - "hash": "sha256-c7e2e14fc114788634b34a4f670f7bf4d27321e5ed40ff446f5a25eef70222c7", + "hash": "sha256-15cd0a56311897b27ee50617491cf69e698053a9f9af7bd37937cbca8da9db13", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu118-x86_64-linux": { - "hash": "sha256-58dad53cfbf1315af464f9d8ba7be9012089c839d4f06a8d2cf8ce0deaf5949a", + "hash": "sha256-ca35ccbb193c795587f4a0ea072fda6f0a0ac7f745f7a68e35c35012098f0a57", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu121-x86_64-linux": { - "hash": "sha256-6519af49c0f689744a7b49497ad2bea1524b69e4095446087d7ab622b898aa30", + "hash": "sha256-e7b12bd79163ee0f520b4a399f69c29e4a692667edf27f7d100f053434d8840c", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu124-x86_64-linux": { - "hash": "sha256-94e0731b58a9ba0e5e2f37b100c8d987c80b5d349008ef625917d020b6c52d25", + "hash": "sha256-f08e850e856faa42c992188affa898a9b5a7be9d64980c4193871b0ad999da78", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu118-x86_64-linux": { - "hash": "sha256-e5b04475538f49d7b4ffded080e4c9c86a658abc12667e3838ebcc410ab1eef4", + "hash": "sha256-9596f1c7cdbc7adf75898d18f370dc33ce0dfab2559301244411f5f4c4e581d4", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu124-x86_64-linux": { - "hash": "sha256-783c02db737a6ec9958b3090f164b87888d3b26e30a4fb6e1cd0c1a635753fab", + "hash": "sha256-90002710f9e59d12bff260ce288c2b2b954f988f94ef920c8384c97946b7782b", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-d230dd53423cf29387350d2e28cc691785135613408edb73c79f5d965dbb30e5", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu126-x86_64-linux": { - "hash": "sha256-a3d81f82f9cfe9d8a6d46758758b3a1b3055d902f41917b4ef2976373db843d6", + "hash": "sha256-fb95eb2faee971ebc0ede12678816c7796b64c723e4fd787aea97397f1c7f5cd", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu118-x86_64-linux": { - "hash": "sha256-f1de67e17944a9816f778c72ae73bbbc90d795cb4885c2f9ee5e0b9a3c57583b", + "hash": "sha256-027930f857347a4f1524fa37244c41c53ffb8c1ebd4eeb72fa32eea4a28b8787", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu124-x86_64-linux": { - "hash": "sha256-789b50d767a5121a7e5a52eaf0c8e897bf1787f049ca08faffb220e5053a5f10", + "hash": "sha256-59ee042d58d57100c415f491a3db905671e094707f786f5f7e3260d5b827ad6a", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-1f9d739bd8198c330b1f2893e0301740c54fa95272233fadb7a95c9b53a70383", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu126-x86_64-linux": { - "hash": "sha256-7c7fe57fea7b9be253085d506f01b2487b2306f22bdffe1de44397fc9f8a3613", + "hash": "sha256-f56c5ea702982b9f75dedeb3a8998550b1b38bcacd77590926234e221fcc571f", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-9c6f2b7fea5327abee2920da86dd57878d5f35aacacc886875050649073d1565", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-fba9bd51e4aa5515ed81193743512dec2129f38555a16a54710e650a717259a8", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-990b615c4b5d2f96874e7f88767681544d84771f3a11443cf0c994759f5e5f75", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-6ad809543e1099f91b022f1393fe9a4527957b854cdfe6c8f4a0632c5497cb9d", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": 
"sha256-90aaa73d93db015c693a4089f2574c2ec2d4943bcee5c9b0ede2834a2c72c370", "hash_type": "git_lfs_concat" } } }, { "repo_id": "kernels-community/rotary", - "sha": "4db658e027ec752840bb3f557ee076413b8db03f", + "sha": "804a326b61f181778b5eb4ebe27aecdb8fbcd845", "variants": { "torch25-cxx11-cu118-x86_64-linux": { - "hash": "sha256-907df2035267a65793985bb7f69fb2a975955fb08c2bbc78c58def43d02801da", + "hash": "sha256-198c67cc7330535da671086c3b6a0dd6189015381f25b409704b51224b25ae3c", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu121-x86_64-linux": { - "hash": "sha256-b614735ae61ee2c1825a3c823fa0cdd3aa07d0bb3f4106001b9e1a557c0ca9b9", + "hash": "sha256-c2e8233d79dd36fc778502c0d44e7399907c2ef064981c7d122fb0652c71eca5", "hash_type": "git_lfs_concat" }, "torch25-cxx11-cu124-x86_64-linux": { - "hash": "sha256-f2e98ec72faaebc1cae25f83ccdbb151868b6902fb5a0623e09d700a514c2a7e", + "hash": "sha256-452040cd5c335a3985da635a76db60a6fc0d9f8b1050fdf29f837d42ee2742ea", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu118-x86_64-linux": { - "hash": "sha256-421214c5a576fac2e0b7998395dccd7f66010f65a6fc647ce06b106ea91105d2", + "hash": "sha256-b627ad5946713c8893f2847eb28f87203f3caaa84f2f35bb9f7b54ea9c3c8a5d", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu121-x86_64-linux": { - "hash": "sha256-9d1c464cf7f391975afa48f2254a639f41582155ad1b50c25bb122418ce8db58", + "hash": "sha256-30311ae1858e29754a4c69e081466e78202ffe8522d08afa46f06350f54cfcd1", "hash_type": "git_lfs_concat" }, "torch25-cxx98-cu124-x86_64-linux": { - "hash": "sha256-82f8012d78304efaa7318f106907630294d10c8b5c9f56923c71df0b03e09f14", + "hash": "sha256-f988c59f5ac640c657f51c7a463f7bcc5ff789109275d8b14f524ad300f9ca55", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu118-x86_64-linux": { - "hash": "sha256-a3247919dcc392efc7e54725dfbce9ee8a796fe4ee53d113048b313de074d3da", + "hash": "sha256-58998893b9992e3ede276388e09c1c31da0b6175d68cf37bcb75bd6f69dba240", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu124-x86_64-linux": { - "hash": "sha256-a21c9734d15946f4cc967d0555d45d7effc6624990c6889fc49162af744fbbe9", + "hash": "sha256-2fdc356b7a5ce2f090dead00253180a750ec9ff72c0afc5f3f07c96e2e603916", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx11-cu126-aarch64-linux": { + "hash": "sha256-d82cd995be25b4b88b0a4086269dcdeb400d0720141fbbfa47bf88cd639ae7e1", "hash_type": "git_lfs_concat" }, "torch26-cxx11-cu126-x86_64-linux": { - "hash": "sha256-01cdda160425b29db0d9bb084874ade4ac081735f9717f272aaefe5bcb379ae1", + "hash": "sha256-a6cd702f278dcbd94f8412d51f79a2664844217b7344bdd24353760c72a789d5", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu118-x86_64-linux": { - "hash": "sha256-17be5b770418ad47101c49d8945b5aa32af9eb5a840bdffb0514d0e264edd860", + "hash": "sha256-c759c2e38a17ea61446afb881cfa2a152d82350e6d38efecbec8ebe1e27cf81f", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu124-x86_64-linux": { - "hash": "sha256-3cd4b9f63cc903e01325b7e5b204e40fc6600c0685f2e19e6f1fa604a599d82d", + "hash": "sha256-d81512fa75acbe8a124b9890bb041fdd1e447794ee210bbb5d01343bd5033eec", + "hash_type": "git_lfs_concat" + }, + "torch26-cxx98-cu126-aarch64-linux": { + "hash": "sha256-a81df695a1b980f899df3c05920a04ff15a89dd28c8cef4067e4e6579669292b", "hash_type": "git_lfs_concat" }, "torch26-cxx98-cu126-x86_64-linux": { - "hash": "sha256-c569f4a4f9b64792507c58d7cfa31dde1285b52125ef07cc98d9f23636af09ca", + "hash": "sha256-868a4b47368a251018bf8f67f3effd8685fed6b01e64725da7e653d38831b166", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": 
"sha256-21ae5790dcf3936b66cd74641f815280ea648dffdc5259b7e1dba3fa5a8fc70d", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-aarch64-linux": { + "hash": "sha256-93466448e31897ef7db0e84e7d6d36824661b15a9841e2476ff181e1eab155c2", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-e0ce52422c82c2ce966c44e61e0d65c789b36feaaeca818f88c2e746201cde9b", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-eb155e56df00ad7d6455f1549d072c39f14c2b7e355f729bf35cb3e62d087df9", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-63b3f8fc56c940d824cdf06d3cc5b504d82c14e005c7d2ca5360e384a2b16af2", "hash_type": "git_lfs_concat" } } diff --git a/server/pyproject.toml b/server/pyproject.toml index 53347f52..7f2addb6 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -9,6 +9,8 @@ authors = [ {name = "Nicolas Patry", email = "nicolas@huggingface.co"}, ] dependencies = [ + # Remove explicit click dependency once typer/click are compatible again. + "click<8.2.0", "einops>=0.8.0", "grpc-interceptor>=0.15.4", "grpcio>=1.67.0", @@ -37,16 +39,16 @@ dependencies = [ ] [[tool.uv.index]] -name = "pytorch-cu124" -url = "https://download.pytorch.org/whl/cu124" +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" explicit = true [tool.uv.sources] torch = [ - { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] torchvision = [ - { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] [build-system] @@ -56,6 +58,7 @@ build-backend = "setuptools.build_meta" [tool.kernels.dependencies] "kernels-community/paged-attention" = ">=0.0.2" "kernels-community/moe" = ">=0.1.1" +"kernels-community/punica-sgmv" = ">=0.0.1" "kernels-community/quantization" = ">=0.0.3" "kernels-community/quantization-eetq" = ">=0.0.1" "kernels-community/rotary" = ">=0.0.1" @@ -92,8 +95,8 @@ gen = [ "mypy-protobuf>=3.6.0", ] torch = [ - "torch==2.6.0", - "torchvision==0.21.0", + "torch==2.7.0", + "torchvision==0.22.0", ] [tool.pytest.ini_options] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 0764bf92..a70d893b 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -1,43 +1,43 @@ # This file was autogenerated by uv via the following command: -# uv pip compile pyproject.toml --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11 -accelerate==1.3.0 +# uv pip compile pyproject.toml --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11 +accelerate==1.6.0 # via # text-generation-server (pyproject.toml) # peft -aiohappyeyeballs==2.4.4 +aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.11.11 +aiohttp==3.11.18 # via # datasets # fsspec aiosignal==1.3.2 # via aiohttp -airportsdata==20241001 +airportsdata==20250224 # via outlines annotated-types==0.7.0 # via pydantic -attention-kernels @ https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl - # via 
text-generation-server (pyproject.toml) -attrs==25.1.0 +attrs==25.3.0 # via # aiohttp # jsonschema # referencing -bitsandbytes==0.45.1 +bitsandbytes==0.45.5 # via text-generation-server (pyproject.toml) -certifi==2024.8.30 +certifi==2025.4.26 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.2 # via requests -click==8.1.7 - # via typer +click==8.1.8 + # via + # text-generation-server (pyproject.toml) + # typer cloudpickle==3.1.1 # via outlines -compressed-tensors==0.9.1 +compressed-tensors==0.9.4 # via text-generation-server (pyproject.toml) datasets==2.21.0 # via text-generation-server (pyproject.toml) -deprecated==1.2.14 +deprecated==1.2.18 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-grpc @@ -49,15 +49,15 @@ dill==0.3.8 # multiprocess diskcache==5.6.3 # via outlines -einops==0.8.0 +einops==0.8.1 # via text-generation-server (pyproject.toml) -filelock==3.16.1 +filelock==3.18.0 # via # datasets # huggingface-hub # torch # transformers -frozenlist==1.5.0 +frozenlist==1.6.0 # via # aiohttp # aiosignal @@ -68,30 +68,36 @@ fsspec==2024.6.1 # torch genson==1.3.0 # via outlines -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.70.0 # via # grpcio-status # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http grpc-interceptor==0.15.4 # via text-generation-server (pyproject.toml) -grpcio==1.68.0 +grpcio==1.71.0 # via # text-generation-server (pyproject.toml) # grpc-interceptor # grpcio-reflection # grpcio-status # opentelemetry-exporter-otlp-proto-grpc -grpcio-reflection==1.68.0 +grpcio-reflection==1.71.0 # via text-generation-server (pyproject.toml) -grpcio-status==1.68.0 +grpcio-status==1.71.0 # via text-generation-server (pyproject.toml) -hf-transfer==0.1.8 +hf-transfer==0.1.9 # via text-generation-server (pyproject.toml) -huggingface-hub==0.28.1 +hf-xet==1.1.0 # via + # text-generation-server (pyproject.toml) + # huggingface-hub +huggingface-hub==0.31.1 + # via + # text-generation-server (pyproject.toml) # accelerate # datasets + # kernels # peft # tokenizers # transformers @@ -99,13 +105,15 @@ idna==3.10 # via # requests # yarl -importlib-metadata==7.1.0 +importlib-metadata==8.6.1 # via opentelemetry-api interegular==0.3.3 # via # outlines # outlines-core -jinja2==3.1.5 +iso3166==2.1.1 + # via outlines +jinja2==3.1.6 # via # outlines # torch @@ -113,8 +121,10 @@ jsonschema==4.23.0 # via # outlines # outlines-core -jsonschema-specifications==2024.10.1 +jsonschema-specifications==2025.4.1 # via jsonschema +kernels==0.5.0 + # via text-generation-server (pyproject.toml) lark==1.2.2 # via outlines loguru==0.7.3 @@ -123,15 +133,11 @@ markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 # via jinja2 -marlin-kernels @ https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl - # via text-generation-server (pyproject.toml) mdurl==0.1.2 # via markdown-it-py -moe-kernels @ https://github.com/danieldk/moe-kernels/releases/download/v0.8.2/moe_kernels-0.8.2+cu123torch2.5-cp39-abi3-linux_x86_64.whl - # via text-generation-server (pyproject.toml) mpmath==1.3.0 # via sympy -multidict==6.1.0 +multidict==6.4.3 # via # aiohttp # yarl @@ -141,7 +147,7 @@ nest-asyncio==1.6.0 # via outlines networkx==3.4.2 # via torch -numpy==1.26.4 +numpy==2.2.5 # via # text-generation-server (pyproject.toml) # accelerate @@ -152,43 +158,44 @@ numpy==1.26.4 # peft # scipy # transformers -nvidia-cublas-cu12==12.4.5.8 +nvidia-cublas-cu12==12.6.4.1 # via # nvidia-cudnn-cu12 # 
nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-cupti-cu12==12.6.80 # via torch -nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.6.77 # via torch -nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.6.77 # via torch -nvidia-cudnn-cu12==9.1.0.70 +nvidia-cudnn-cu12==9.5.1.17 # via torch -nvidia-cufft-cu12==11.2.1.3 +nvidia-cufft-cu12==11.3.0.4 # via torch -nvidia-curand-cu12==10.3.5.147 +nvidia-cufile-cu12==1.11.1.6 # via torch -nvidia-cusolver-cu12==11.6.1.9 +nvidia-curand-cu12==10.3.7.77 # via torch -nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusolver-cu12==11.7.1.2 + # via torch +nvidia-cusparse-cu12==12.5.4.2 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.2 +nvidia-cusparselt-cu12==0.6.3 # via torch -nvidia-ml-py==12.570.86 - # via moe-kernels -nvidia-nccl-cu12==2.21.5 +nvidia-nccl-cu12==2.26.2 # via torch -nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvjitlink-cu12==12.6.85 # via + # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.4.127 +nvidia-nvtx-cu12==12.6.77 # via torch -opentelemetry-api==1.30.0 +opentelemetry-api==1.33.0 # via # text-generation-server (pyproject.toml) # opentelemetry-exporter-otlp-proto-grpc @@ -197,86 +204,85 @@ opentelemetry-api==1.30.0 # opentelemetry-instrumentation-grpc # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-exporter-otlp==1.30.0 +opentelemetry-exporter-otlp==1.33.0 # via text-generation-server (pyproject.toml) -opentelemetry-exporter-otlp-proto-common==1.30.0 +opentelemetry-exporter-otlp-proto-common==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-exporter-otlp-proto-grpc==1.30.0 +opentelemetry-exporter-otlp-proto-grpc==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-exporter-otlp-proto-http==1.30.0 +opentelemetry-exporter-otlp-proto-http==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-instrumentation==0.51b0 +opentelemetry-instrumentation==0.54b0 # via opentelemetry-instrumentation-grpc -opentelemetry-instrumentation-grpc==0.51b0 +opentelemetry-instrumentation-grpc==0.54b0 # via text-generation-server (pyproject.toml) -opentelemetry-proto==1.30.0 +opentelemetry-proto==1.33.0 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-sdk==1.30.0 +opentelemetry-sdk==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-semantic-conventions==0.51b0 +opentelemetry-semantic-conventions==0.54b0 # via # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc # opentelemetry-sdk -outlines==0.1.14 +outlines==0.2.3 # via text-generation-server (pyproject.toml) outlines-core==0.1.26 # via outlines -packaging==24.1 +packaging==25.0 # via # accelerate # datasets # huggingface-hub + # kernels # opentelemetry-instrumentation # peft # transformers pandas==2.2.3 # via datasets -peft==0.14.0 +peft==0.15.2 # via text-generation-server (pyproject.toml) -pillow==11.1.0 +pillow==11.2.1 # via text-generation-server (pyproject.toml) prometheus-client==0.21.1 # via text-generation-server (pyproject.toml) -propcache==0.2.1 +propcache==0.3.1 # via # aiohttp # yarl -protobuf==5.29.3 +protobuf==5.29.4 # via # text-generation-server (pyproject.toml) # googleapis-common-protos # grpcio-reflection # grpcio-status # opentelemetry-proto -psutil==6.1.1 +psutil==7.0.0 # via # accelerate # peft py-cpuinfo==9.0.0 # via 
text-generation-server (pyproject.toml) -pyarrow==19.0.0 +pyarrow==20.0.0 # via datasets -pycountry==24.6.1 - # via outlines -pydantic==2.10.6 +pydantic==2.11.4 # via # compressed-tensors # outlines -pydantic-core==2.27.2 +pydantic-core==2.33.2 # via pydantic -pygments==2.18.0 +pygments==2.19.1 # via rich python-dateutil==2.9.0.post0 # via pandas -pytz==2025.1 +pytz==2025.2 # via pandas pyyaml==6.0.2 # via @@ -290,7 +296,7 @@ referencing==0.36.2 # jsonschema # jsonschema-specifications # outlines -regex==2024.9.11 +regex==2024.11.6 # via transformers requests==2.32.3 # via @@ -299,65 +305,62 @@ requests==2.32.3 # opentelemetry-exporter-otlp-proto-http # outlines # transformers -rich==13.9.4 +rich==14.0.0 # via # text-generation-server (pyproject.toml) # typer -rpds-py==0.22.3 +rpds-py==0.24.0 # via # jsonschema # referencing -safetensors==0.4.5 +safetensors==0.5.3 # via # text-generation-server (pyproject.toml) # accelerate # peft # transformers -scipy==1.13.1 +scipy==1.15.3 # via text-generation-server (pyproject.toml) sentencepiece==0.2.0 # via text-generation-server (pyproject.toml) +setuptools==80.4.0 + # via triton shellingham==1.5.4 # via typer six==1.17.0 # via python-dateutil -sympy==1.13.1 +sympy==1.14.0 # via torch texttable==1.7.0 # via text-generation-server (pyproject.toml) -tokenizers==0.21.0 +tokenizers==0.21.1 # via # text-generation-server (pyproject.toml) # transformers -torch==2.6.0 +torch==2.7.0 # via # accelerate - # attention-kernels # bitsandbytes # compressed-tensors - # marlin-kernels - # moe-kernels # outlines # peft -tqdm==4.66.5 +tqdm==4.67.1 # via # datasets # huggingface-hub # outlines # peft # transformers -transformers==4.49 +transformers==4.51.3 # via # text-generation-server (pyproject.toml) # compressed-tensors # peft -triton==3.2.0 - # via - # moe-kernels - # torch -typer==0.15.1 +triton==3.3.0 + # via torch +typer==0.15.3 # via text-generation-server (pyproject.toml) -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # huggingface-hub # opentelemetry-sdk @@ -367,18 +370,21 @@ typing-extensions==4.12.2 # referencing # torch # typer -tzdata==2025.1 + # typing-inspection +typing-inspection==0.4.0 + # via pydantic +tzdata==2025.2 # via pandas -urllib3==2.2.3 +urllib3==2.4.0 # via requests -wrapt==1.16.0 +wrapt==1.17.2 # via # deprecated # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc xxhash==3.5.0 # via datasets -yarl==1.18.3 +yarl==1.20.0 # via aiohttp -zipp==3.20.2 +zipp==3.21.0 # via importlib-metadata diff --git a/server/requirements_gen.txt b/server/requirements_gen.txt index 6d64a34b..0cee5d6c 100644 --- a/server/requirements_gen.txt +++ b/server/requirements_gen.txt @@ -1,33 +1,35 @@ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml --extra gen -o requirements_gen.txt --python-version 3.11 -certifi==2025.1.31 +certifi==2025.4.26 # via requests -charset-normalizer==3.4.1 +charset-normalizer==3.4.2 # via requests click==8.1.8 - # via typer + # via + # text-generation-server (pyproject.toml) + # typer deprecated==1.2.18 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http # opentelemetry-semantic-conventions -einops==0.8.0 +einops==0.8.1 # via text-generation-server (pyproject.toml) -filelock==3.17.0 +filelock==3.18.0 # via # huggingface-hub # transformers -fsspec==2025.2.0 +fsspec==2025.3.2 # via huggingface-hub -googleapis-common-protos==1.66.0 +googleapis-common-protos==1.70.0 # via # grpcio-status # 
opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http grpc-interceptor==0.15.4 # via text-generation-server (pyproject.toml) -grpcio==1.70.0 +grpcio==1.71.0 # via # text-generation-server (pyproject.toml) # grpc-interceptor @@ -35,22 +37,30 @@ grpcio==1.70.0 # grpcio-status # grpcio-tools # opentelemetry-exporter-otlp-proto-grpc -grpcio-reflection==1.70.0 +grpcio-reflection==1.71.0 # via text-generation-server (pyproject.toml) -grpcio-status==1.70.0 +grpcio-status==1.71.0 # via text-generation-server (pyproject.toml) -grpcio-tools==1.70.0 +grpcio-tools==1.71.0 # via text-generation-server (pyproject.toml) hf-transfer==0.1.9 # via text-generation-server (pyproject.toml) -huggingface-hub==0.28.1 +hf-xet==1.1.0 # via + # text-generation-server (pyproject.toml) + # huggingface-hub +huggingface-hub==0.31.1 + # via + # text-generation-server (pyproject.toml) + # kernels # tokenizers # transformers idna==3.10 # via requests -importlib-metadata==8.5.0 +importlib-metadata==8.6.1 # via opentelemetry-api +kernels==0.5.0 + # via text-generation-server (pyproject.toml) loguru==0.7.3 # via text-generation-server (pyproject.toml) markdown-it-py==3.0.0 @@ -59,12 +69,12 @@ mdurl==0.1.2 # via markdown-it-py mypy-protobuf==3.6.0 # via text-generation-server (pyproject.toml) -numpy==2.2.2 +numpy==2.2.5 # via # text-generation-server (pyproject.toml) # scipy # transformers -opentelemetry-api==1.30.0 +opentelemetry-api==1.33.0 # via # text-generation-server (pyproject.toml) # opentelemetry-exporter-otlp-proto-grpc @@ -73,44 +83,45 @@ opentelemetry-api==1.30.0 # opentelemetry-instrumentation-grpc # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-exporter-otlp==1.30.0 +opentelemetry-exporter-otlp==1.33.0 # via text-generation-server (pyproject.toml) -opentelemetry-exporter-otlp-proto-common==1.30.0 +opentelemetry-exporter-otlp-proto-common==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-exporter-otlp-proto-grpc==1.30.0 +opentelemetry-exporter-otlp-proto-grpc==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-exporter-otlp-proto-http==1.30.0 +opentelemetry-exporter-otlp-proto-http==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-instrumentation==0.51b0 +opentelemetry-instrumentation==0.54b0 # via opentelemetry-instrumentation-grpc -opentelemetry-instrumentation-grpc==0.51b0 +opentelemetry-instrumentation-grpc==0.54b0 # via text-generation-server (pyproject.toml) -opentelemetry-proto==1.30.0 +opentelemetry-proto==1.33.0 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-sdk==1.30.0 +opentelemetry-sdk==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-semantic-conventions==0.51b0 +opentelemetry-semantic-conventions==0.54b0 # via # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc # opentelemetry-sdk -packaging==24.2 +packaging==25.0 # via # huggingface-hub + # kernels # opentelemetry-instrumentation # transformers -pillow==11.1.0 +pillow==11.2.1 # via text-generation-server (pyproject.toml) prometheus-client==0.21.1 # via text-generation-server (pyproject.toml) -protobuf==5.29.3 +protobuf==5.29.4 # via # text-generation-server (pyproject.toml) # googleapis-common-protos @@ -134,23 +145,23 @@ requests==2.32.3 # huggingface-hub # opentelemetry-exporter-otlp-proto-http # transformers -rich==13.9.4 +rich==14.0.0 # via # 
text-generation-server (pyproject.toml) # typer -safetensors==0.5.2 +safetensors==0.5.3 # via # text-generation-server (pyproject.toml) # transformers -scipy==1.15.1 +scipy==1.15.3 # via text-generation-server (pyproject.toml) sentencepiece==0.2.0 # via text-generation-server (pyproject.toml) -setuptools==75.8.0 +setuptools==80.4.0 # via grpcio-tools shellingham==1.5.4 # via typer -tokenizers==0.21.0 +tokenizers==0.21.1 # via # text-generation-server (pyproject.toml) # transformers @@ -158,18 +169,18 @@ tqdm==4.67.1 # via # huggingface-hub # transformers -transformers==4.49 +transformers==4.51.3 # via text-generation-server (pyproject.toml) -typer==0.15.1 +typer==0.15.3 # via text-generation-server (pyproject.toml) -types-protobuf==5.29.1.20241207 +types-protobuf==6.30.2.20250506 # via mypy-protobuf -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # huggingface-hub # opentelemetry-sdk # typer -urllib3==2.3.0 +urllib3==2.4.0 # via requests wrapt==1.17.2 # via diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index c671199f..0cad583c 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -1,39 +1,41 @@ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_intel.txt --python-version 3.11 -accelerate==1.3.0 +accelerate==1.6.0 # via # text-generation-server (pyproject.toml) # peft -aiohappyeyeballs==2.4.4 +aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.11.11 +aiohttp==3.11.18 # via # datasets # fsspec aiosignal==1.3.2 # via aiohttp -airportsdata==20241001 +airportsdata==20250224 # via outlines annotated-types==0.7.0 # via pydantic -attrs==25.1.0 +attrs==25.3.0 # via # aiohttp # jsonschema # referencing -certifi==2024.8.30 +certifi==2025.4.26 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.2 # via requests -click==8.1.7 - # via typer +click==8.1.8 + # via + # text-generation-server (pyproject.toml) + # typer cloudpickle==3.1.1 # via outlines -compressed-tensors==0.9.1 +compressed-tensors==0.9.4 # via text-generation-server (pyproject.toml) datasets==2.21.0 # via text-generation-server (pyproject.toml) -deprecated==1.2.14 +deprecated==1.2.18 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-grpc @@ -45,15 +47,15 @@ dill==0.3.8 # multiprocess diskcache==5.6.3 # via outlines -einops==0.8.0 +einops==0.8.1 # via text-generation-server (pyproject.toml) -filelock==3.16.1 +filelock==3.18.0 # via # datasets # huggingface-hub # torch # transformers -frozenlist==1.5.0 +frozenlist==1.6.0 # via # aiohttp # aiosignal @@ -64,30 +66,36 @@ fsspec==2024.6.1 # torch genson==1.3.0 # via outlines -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.70.0 # via # grpcio-status # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http grpc-interceptor==0.15.4 # via text-generation-server (pyproject.toml) -grpcio==1.68.0 +grpcio==1.71.0 # via # text-generation-server (pyproject.toml) # grpc-interceptor # grpcio-reflection # grpcio-status # opentelemetry-exporter-otlp-proto-grpc -grpcio-reflection==1.68.0 +grpcio-reflection==1.71.0 # via text-generation-server (pyproject.toml) -grpcio-status==1.68.0 +grpcio-status==1.71.0 # via text-generation-server (pyproject.toml) -hf-transfer==0.1.8 +hf-transfer==0.1.9 # via text-generation-server (pyproject.toml) -huggingface-hub==0.28.1 +hf-xet==1.1.0 # via + # text-generation-server (pyproject.toml) + # 
huggingface-hub +huggingface-hub==0.31.1 + # via + # text-generation-server (pyproject.toml) # accelerate # datasets + # kernels # peft # tokenizers # transformers @@ -95,13 +103,15 @@ idna==3.10 # via # requests # yarl -importlib-metadata==7.1.0 +importlib-metadata==8.6.1 # via opentelemetry-api interegular==0.3.3 # via # outlines # outlines-core -jinja2==3.1.5 +iso3166==2.1.1 + # via outlines +jinja2==3.1.6 # via # outlines # torch @@ -109,8 +119,10 @@ jsonschema==4.23.0 # via # outlines # outlines-core -jsonschema-specifications==2024.10.1 +jsonschema-specifications==2025.4.1 # via jsonschema +kernels==0.5.0 + # via text-generation-server (pyproject.toml) lark==1.2.2 # via outlines loguru==0.7.3 @@ -123,7 +135,7 @@ mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -multidict==6.1.0 +multidict==6.4.3 # via # aiohttp # yarl @@ -133,7 +145,7 @@ nest-asyncio==1.6.0 # via outlines networkx==3.4.2 # via torch -numpy==1.26.4 +numpy==2.2.5 # via # text-generation-server (pyproject.toml) # accelerate @@ -143,41 +155,44 @@ numpy==1.26.4 # peft # scipy # transformers -nvidia-cublas-cu12==12.4.5.8 +nvidia-cublas-cu12==12.6.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-cupti-cu12==12.6.80 # via torch -nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.6.77 # via torch -nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.6.77 # via torch -nvidia-cudnn-cu12==9.1.0.70 +nvidia-cudnn-cu12==9.5.1.17 # via torch -nvidia-cufft-cu12==11.2.1.3 +nvidia-cufft-cu12==11.3.0.4 # via torch -nvidia-curand-cu12==10.3.5.147 +nvidia-cufile-cu12==1.11.1.6 # via torch -nvidia-cusolver-cu12==11.6.1.9 +nvidia-curand-cu12==10.3.7.77 # via torch -nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusolver-cu12==11.7.1.2 + # via torch +nvidia-cusparse-cu12==12.5.4.2 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.2 +nvidia-cusparselt-cu12==0.6.3 # via torch -nvidia-nccl-cu12==2.21.5 +nvidia-nccl-cu12==2.26.2 # via torch -nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvjitlink-cu12==12.6.85 # via + # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.4.127 +nvidia-nvtx-cu12==12.6.77 # via torch -opentelemetry-api==1.30.0 +opentelemetry-api==1.33.0 # via # text-generation-server (pyproject.toml) # opentelemetry-exporter-otlp-proto-grpc @@ -186,86 +201,85 @@ opentelemetry-api==1.30.0 # opentelemetry-instrumentation-grpc # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-exporter-otlp==1.30.0 +opentelemetry-exporter-otlp==1.33.0 # via text-generation-server (pyproject.toml) -opentelemetry-exporter-otlp-proto-common==1.30.0 +opentelemetry-exporter-otlp-proto-common==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-exporter-otlp-proto-grpc==1.30.0 +opentelemetry-exporter-otlp-proto-grpc==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-exporter-otlp-proto-http==1.30.0 +opentelemetry-exporter-otlp-proto-http==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-instrumentation==0.51b0 +opentelemetry-instrumentation==0.54b0 # via opentelemetry-instrumentation-grpc -opentelemetry-instrumentation-grpc==0.51b0 +opentelemetry-instrumentation-grpc==0.54b0 # via text-generation-server (pyproject.toml) -opentelemetry-proto==1.30.0 +opentelemetry-proto==1.33.0 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http 
-opentelemetry-sdk==1.30.0 +opentelemetry-sdk==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-semantic-conventions==0.51b0 +opentelemetry-semantic-conventions==0.54b0 # via # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc # opentelemetry-sdk -outlines==0.1.14 +outlines==0.2.3 # via text-generation-server (pyproject.toml) outlines-core==0.1.26 # via outlines -packaging==24.1 +packaging==25.0 # via # accelerate # datasets # huggingface-hub + # kernels # opentelemetry-instrumentation # peft # transformers pandas==2.2.3 # via datasets -peft==0.14.0 +peft==0.15.2 # via text-generation-server (pyproject.toml) -pillow==11.1.0 +pillow==11.2.1 # via text-generation-server (pyproject.toml) prometheus-client==0.21.1 # via text-generation-server (pyproject.toml) -propcache==0.2.1 +propcache==0.3.1 # via # aiohttp # yarl -protobuf==5.29.3 +protobuf==5.29.4 # via # text-generation-server (pyproject.toml) # googleapis-common-protos # grpcio-reflection # grpcio-status # opentelemetry-proto -psutil==6.1.1 +psutil==7.0.0 # via # accelerate # peft py-cpuinfo==9.0.0 # via text-generation-server (pyproject.toml) -pyarrow==19.0.0 +pyarrow==20.0.0 # via datasets -pycountry==24.6.1 - # via outlines -pydantic==2.10.6 +pydantic==2.11.4 # via # compressed-tensors # outlines -pydantic-core==2.27.2 +pydantic-core==2.33.2 # via pydantic -pygments==2.18.0 +pygments==2.19.1 # via rich python-dateutil==2.9.0.post0 # via pandas -pytz==2025.1 +pytz==2025.2 # via pandas pyyaml==6.0.2 # via @@ -279,7 +293,7 @@ referencing==0.36.2 # jsonschema # jsonschema-specifications # outlines -regex==2024.9.11 +regex==2024.11.6 # via transformers requests==2.32.3 # via @@ -288,59 +302,61 @@ requests==2.32.3 # opentelemetry-exporter-otlp-proto-http # outlines # transformers -rich==13.9.4 +rich==14.0.0 # via # text-generation-server (pyproject.toml) # typer -rpds-py==0.22.3 +rpds-py==0.24.0 # via # jsonschema # referencing -safetensors==0.4.5 +safetensors==0.5.3 # via # text-generation-server (pyproject.toml) # accelerate # peft # transformers -scipy==1.13.1 +scipy==1.15.3 # via text-generation-server (pyproject.toml) sentencepiece==0.2.0 # via text-generation-server (pyproject.toml) +setuptools==80.4.0 + # via triton shellingham==1.5.4 # via typer six==1.17.0 # via python-dateutil -sympy==1.13.1 +sympy==1.14.0 # via torch texttable==1.7.0 # via text-generation-server (pyproject.toml) -tokenizers==0.21.0 +tokenizers==0.21.1 # via # text-generation-server (pyproject.toml) # transformers -torch==2.6.0 +torch==2.7.0 # via # accelerate # compressed-tensors # outlines # peft -tqdm==4.66.5 +tqdm==4.67.1 # via # datasets # huggingface-hub # outlines # peft # transformers -transformers==4.49 +transformers==4.51.3 # via # text-generation-server (pyproject.toml) # compressed-tensors # peft -triton==3.2.0 +triton==3.3.0 # via torch -typer==0.15.1 +typer==0.15.3 # via text-generation-server (pyproject.toml) -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # huggingface-hub # opentelemetry-sdk @@ -350,18 +366,21 @@ typing-extensions==4.12.2 # referencing # torch # typer -tzdata==2025.1 + # typing-inspection +typing-inspection==0.4.0 + # via pydantic +tzdata==2025.2 # via pandas -urllib3==2.2.3 +urllib3==2.4.0 # via requests -wrapt==1.16.0 +wrapt==1.17.2 # via # deprecated # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc xxhash==3.5.0 # via datasets -yarl==1.18.3 +yarl==1.20.0 # via aiohttp -zipp==3.20.2 +zipp==3.21.0 # via 
importlib-metadata diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index fe7ca572..1f71c8e6 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -1,39 +1,41 @@ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_rocm.txt --python-version 3.11 -accelerate==1.3.0 +accelerate==1.6.0 # via # text-generation-server (pyproject.toml) # peft -aiohappyeyeballs==2.4.4 +aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.11.11 +aiohttp==3.11.18 # via # datasets # fsspec aiosignal==1.3.2 # via aiohttp -airportsdata==20241001 +airportsdata==20250224 # via outlines annotated-types==0.7.0 # via pydantic -attrs==25.1.0 +attrs==25.3.0 # via # aiohttp # jsonschema # referencing -certifi==2024.8.30 +certifi==2025.4.26 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.2 # via requests -click==8.1.7 - # via typer +click==8.1.8 + # via + # text-generation-server (pyproject.toml) + # typer cloudpickle==3.1.1 # via outlines -compressed-tensors==0.9.1 +compressed-tensors==0.9.4 # via text-generation-server (pyproject.toml) datasets==2.21.0 # via text-generation-server (pyproject.toml) -deprecated==1.2.14 +deprecated==1.2.18 # via # opentelemetry-api # opentelemetry-exporter-otlp-proto-grpc @@ -45,15 +47,15 @@ dill==0.3.8 # multiprocess diskcache==5.6.3 # via outlines -einops==0.8.0 +einops==0.8.1 # via text-generation-server (pyproject.toml) -filelock==3.16.1 +filelock==3.18.0 # via # datasets # huggingface-hub # torch # transformers -frozenlist==1.5.0 +frozenlist==1.6.0 # via # aiohttp # aiosignal @@ -64,30 +66,36 @@ fsspec==2024.6.1 # torch genson==1.3.0 # via outlines -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.70.0 # via # grpcio-status # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http grpc-interceptor==0.15.4 # via text-generation-server (pyproject.toml) -grpcio==1.68.0 +grpcio==1.71.0 # via # text-generation-server (pyproject.toml) # grpc-interceptor # grpcio-reflection # grpcio-status # opentelemetry-exporter-otlp-proto-grpc -grpcio-reflection==1.68.0 +grpcio-reflection==1.71.0 # via text-generation-server (pyproject.toml) -grpcio-status==1.68.0 +grpcio-status==1.71.0 # via text-generation-server (pyproject.toml) -hf-transfer==0.1.8 +hf-transfer==0.1.9 # via text-generation-server (pyproject.toml) -huggingface-hub==0.28.1 +hf-xet==1.1.0 # via + # text-generation-server (pyproject.toml) + # huggingface-hub +huggingface-hub==0.31.1 + # via + # text-generation-server (pyproject.toml) # accelerate # datasets + # kernels # peft # tokenizers # transformers @@ -95,13 +103,15 @@ idna==3.10 # via # requests # yarl -importlib-metadata==7.1.0 +importlib-metadata==8.6.1 # via opentelemetry-api interegular==0.3.3 # via # outlines # outlines-core -jinja2==3.1.5 +iso3166==2.1.1 + # via outlines +jinja2==3.1.6 # via # outlines # torch @@ -109,8 +119,10 @@ jsonschema==4.23.0 # via # outlines # outlines-core -jsonschema-specifications==2024.10.1 +jsonschema-specifications==2025.4.1 # via jsonschema +kernels==0.5.0 + # via text-generation-server (pyproject.toml) lark==1.2.2 # via outlines loguru==0.7.3 @@ -123,7 +135,7 @@ mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -multidict==6.1.0 +multidict==6.4.3 # via # aiohttp # yarl @@ -133,7 +145,7 @@ nest-asyncio==1.6.0 # via outlines networkx==3.4.2 # via torch -numpy==1.26.4 +numpy==2.2.5 # via # 
text-generation-server (pyproject.toml) # accelerate @@ -143,41 +155,44 @@ numpy==1.26.4 # peft # scipy # transformers -nvidia-cublas-cu12==12.4.5.8 +nvidia-cublas-cu12==12.6.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-cupti-cu12==12.6.80 # via torch -nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.6.77 # via torch -nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.6.77 # via torch -nvidia-cudnn-cu12==9.1.0.70 +nvidia-cudnn-cu12==9.5.1.17 # via torch -nvidia-cufft-cu12==11.2.1.3 +nvidia-cufft-cu12==11.3.0.4 # via torch -nvidia-curand-cu12==10.3.5.147 +nvidia-cufile-cu12==1.11.1.6 # via torch -nvidia-cusolver-cu12==11.6.1.9 +nvidia-curand-cu12==10.3.7.77 # via torch -nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusolver-cu12==11.7.1.2 + # via torch +nvidia-cusparse-cu12==12.5.4.2 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.2 +nvidia-cusparselt-cu12==0.6.3 # via torch -nvidia-nccl-cu12==2.21.5 +nvidia-nccl-cu12==2.26.2 # via torch -nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvjitlink-cu12==12.6.85 # via + # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.4.127 +nvidia-nvtx-cu12==12.6.77 # via torch -opentelemetry-api==1.30.0 +opentelemetry-api==1.33.0 # via # text-generation-server (pyproject.toml) # opentelemetry-exporter-otlp-proto-grpc @@ -186,86 +201,85 @@ opentelemetry-api==1.30.0 # opentelemetry-instrumentation-grpc # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-exporter-otlp==1.30.0 +opentelemetry-exporter-otlp==1.33.0 # via text-generation-server (pyproject.toml) -opentelemetry-exporter-otlp-proto-common==1.30.0 +opentelemetry-exporter-otlp-proto-common==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-exporter-otlp-proto-grpc==1.30.0 +opentelemetry-exporter-otlp-proto-grpc==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-exporter-otlp-proto-http==1.30.0 +opentelemetry-exporter-otlp-proto-http==1.33.0 # via opentelemetry-exporter-otlp -opentelemetry-instrumentation==0.51b0 +opentelemetry-instrumentation==0.54b0 # via opentelemetry-instrumentation-grpc -opentelemetry-instrumentation-grpc==0.51b0 +opentelemetry-instrumentation-grpc==0.54b0 # via text-generation-server (pyproject.toml) -opentelemetry-proto==1.30.0 +opentelemetry-proto==1.33.0 # via # opentelemetry-exporter-otlp-proto-common # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-sdk==1.30.0 +opentelemetry-sdk==1.33.0 # via # opentelemetry-exporter-otlp-proto-grpc # opentelemetry-exporter-otlp-proto-http -opentelemetry-semantic-conventions==0.51b0 +opentelemetry-semantic-conventions==0.54b0 # via # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc # opentelemetry-sdk -outlines==0.1.14 +outlines==0.2.3 # via text-generation-server (pyproject.toml) outlines-core==0.1.26 # via outlines -packaging==24.1 +packaging==25.0 # via # accelerate # datasets # huggingface-hub + # kernels # opentelemetry-instrumentation # peft # transformers pandas==2.2.3 # via datasets -peft==0.14.0 +peft==0.15.2 # via text-generation-server (pyproject.toml) -pillow==11.1.0 +pillow==11.2.1 # via text-generation-server (pyproject.toml) prometheus-client==0.21.1 # via text-generation-server (pyproject.toml) -propcache==0.2.1 +propcache==0.3.1 # via # aiohttp # yarl -protobuf==5.29.3 +protobuf==5.29.4 # via # text-generation-server (pyproject.toml) # 
googleapis-common-protos # grpcio-reflection # grpcio-status # opentelemetry-proto -psutil==6.1.1 +psutil==7.0.0 # via # accelerate # peft py-cpuinfo==9.0.0 # via text-generation-server (pyproject.toml) -pyarrow==19.0.0 +pyarrow==20.0.0 # via datasets -pycountry==24.6.1 - # via outlines -pydantic==2.10.6 +pydantic==2.11.4 # via # compressed-tensors # outlines -pydantic-core==2.27.2 +pydantic-core==2.33.2 # via pydantic -pygments==2.18.0 +pygments==2.19.1 # via rich python-dateutil==2.9.0.post0 # via pandas -pytz==2025.1 +pytz==2025.2 # via pandas pyyaml==6.0.2 # via @@ -279,7 +293,7 @@ referencing==0.36.2 # jsonschema # jsonschema-specifications # outlines -regex==2024.9.11 +regex==2024.11.6 # via transformers requests==2.32.3 # via @@ -288,59 +302,61 @@ requests==2.32.3 # opentelemetry-exporter-otlp-proto-http # outlines # transformers -rich==13.9.4 +rich==14.0.0 # via # text-generation-server (pyproject.toml) # typer -rpds-py==0.22.3 +rpds-py==0.24.0 # via # jsonschema # referencing -safetensors==0.4.5 +safetensors==0.5.3 # via # text-generation-server (pyproject.toml) # accelerate # peft # transformers -scipy==1.13.1 +scipy==1.15.3 # via text-generation-server (pyproject.toml) sentencepiece==0.2.0 # via text-generation-server (pyproject.toml) +setuptools==80.4.0 + # via triton shellingham==1.5.4 # via typer six==1.17.0 # via python-dateutil -sympy==1.13.1 +sympy==1.14.0 # via torch texttable==1.7.0 # via text-generation-server (pyproject.toml) -tokenizers==0.21.0 +tokenizers==0.21.1 # via # text-generation-server (pyproject.toml) # transformers -torch==2.6.0 +torch==2.7.0 # via # accelerate # compressed-tensors # outlines # peft -tqdm==4.66.5 +tqdm==4.67.1 # via # datasets # huggingface-hub # outlines # peft # transformers -transformers==4.49 +transformers==4.51.3 # via # text-generation-server (pyproject.toml) # compressed-tensors # peft -triton==3.2.0 +triton==3.3.0 # via torch -typer==0.15.1 +typer==0.15.3 # via text-generation-server (pyproject.toml) -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via # huggingface-hub # opentelemetry-sdk @@ -350,18 +366,21 @@ typing-extensions==4.12.2 # referencing # torch # typer -tzdata==2025.1 + # typing-inspection +typing-inspection==0.4.0 + # via pydantic +tzdata==2025.2 # via pandas -urllib3==2.2.3 +urllib3==2.4.0 # via requests -wrapt==1.16.0 +wrapt==1.17.2 # via # deprecated # opentelemetry-instrumentation # opentelemetry-instrumentation-grpc xxhash==3.5.0 # via datasets -yarl==1.18.3 +yarl==1.20.0 # via aiohttp -zipp==3.20.2 +zipp==3.21.0 # via importlib-metadata diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index 782d66e4..c8eb48a2 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -13,21 +13,20 @@ from torch.distributed import ProcessGroup from text_generation_server.utils.log import log_master from text_generation_server.adapters.config import AdapterConfig, ModuleMap - +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.kernels import load_kernel from text_generation_server.adapters.weights import ( AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, ) -from text_generation_server.utils.sgmv import ( - BGMV_MAX_RANK, - MAX_RANK_CUSTOM, - get_tmp_tensors, - orient_for_rank, - pad_rank, - use_cutlass_shrink, - has_sgmv, -) + +if SYSTEM == "cuda": + punica_sgmv = load_kernel( + module="punica_sgmv", repo_id="kernels-community/punica-sgmv" + ) +else: + 
punica_sgmv = None def get_start_stop_idxs_for_rank(offset, size, rank, world_size): @@ -129,11 +128,13 @@ class LoraWeights(AdapterWeights): self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 - self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r) self._is_transposed = False # [num_layers, hidden_size, r] - weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + weights_a = [ + punica_sgmv.orient_for_rank(w, w.size(1)).contiguous() for w in weights_a + ] self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] @@ -244,8 +245,12 @@ class LoraWeights(AdapterWeights): lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale # pad lora ranks to be compatible with sgmv - lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list] - lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list] + lora_a_list = [ + punica_sgmv.pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list + ] + lora_b_list = [ + punica_sgmv.pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list + ] if lora_a_list: # update rank if it was padded @@ -293,7 +298,7 @@ class BatchLoraWeights(BatchAdapterWeights): def can_vectorize(self, pg: ProcessGroup) -> bool: return all( - rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM for rank_data in self.rank_data.values() ) @@ -337,8 +342,8 @@ class BatchLoraWeights(BatchAdapterWeights): ) use_sgmv = False - if prefill or max_rank > BGMV_MAX_RANK: - if has_sgmv(): + if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK: + if punica_sgmv is not None: use_sgmv = True lora_a_ptr = torch.tensor( [ @@ -425,7 +430,7 @@ class BatchLoraWeights(BatchAdapterWeights): if use_sgmv: lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors( + tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors( lora_a_ptr_indices.size(0), rank, device ) segment_starts = meta.adapter_segments[indices] diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py index a4537b55..abfb097d 100644 --- a/server/text_generation_server/layers/lora.py +++ b/server/text_generation_server/layers/lora.py @@ -5,14 +5,16 @@ import torch.distributed from torch import nn from torch.distributed import ProcessGroup -from text_generation_server.utils.sgmv import ( - add_lora_a_bgmv, - add_lora_b_bgmv, - has_sgmv, - lora_a_sgmv_cutlass, - lora_b_sgmv_cutlass, - orient_for_rank, -) +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.kernels import load_kernel + +if SYSTEM == "cuda": + punica_sgmv = load_kernel( + module="punica_sgmv", repo_id="kernels-community/punica-sgmv" + ) +else: + punica_sgmv = None + if TYPE_CHECKING: from text_generation_server.adapters import AdapterBatchData @@ -41,7 +43,11 @@ class LoraLinear(nn.Module): return result data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type) - if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + if ( + punica_sgmv is not None + and data is not None + and data.can_vectorize(self.process_group) + ): # In tensor-parallel configurations, each GPU processes a specific segment of the output. # The 'result' tensor represents the full output, which can vary in size based on # the layer type (e.g., attention vs. 
feed-forward layers). We define the current @@ -68,7 +74,7 @@ class LoraLinear(nn.Module): if data.use_sgmv: # Use SGMV for prefill - v = lora_a_sgmv_cutlass( + v = punica_sgmv.lora_a_sgmv_cutlass( input, rank_segments.tmp_shrink, lora_a_ptr, @@ -81,7 +87,7 @@ class LoraLinear(nn.Module): if self.process_group.size() > 1: v = self.collect_lora_a(v) - lora_b_sgmv_cutlass( + punica_sgmv.lora_b_sgmv_cutlass( proj, v, rank_segments.tmp_expand, @@ -96,7 +102,7 @@ class LoraLinear(nn.Module): (input.size(0), r), dtype=input.dtype, device=input.device ) # TODO: error with [-1, 0], but not [0, -1] - add_lora_a_bgmv( + punica_sgmv.add_lora_a_bgmv( v, input, lora_a_ptr, @@ -107,7 +113,7 @@ class LoraLinear(nn.Module): if self.process_group.size() > 1: v = self.collect_lora_a(v) - add_lora_b_bgmv( + punica_sgmv.add_lora_b_bgmv( proj, v, lora_b_ptr, @@ -142,7 +148,7 @@ class LoraLinear(nn.Module): lora_a = data.lora_a[adapter_index][self.layer_id, :, :] lora_b = data.lora_b[adapter_index][self.layer_id, :, :] - lora_a = orient_for_rank(lora_a, lora_b.size(0)) + lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0)) a_out = input @ lora_a if self.process_group.size() > 1: diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py deleted file mode 100644 index 2d0a73a5..00000000 --- a/server/text_generation_server/utils/sgmv.py +++ /dev/null @@ -1,252 +0,0 @@ -# Origin: https://github.com/predibase/lorax -# Path: lorax/server/lorax_server/utils/sgmv.py -# License: Apache License Version 2.0, January 2004 - -import os -import warnings -from functools import lru_cache -from typing import List, Tuple - -import torch -import torch.nn.functional as F - -try: - import punica_kernels as _kernels - - HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) -except ImportError: - warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") - _kernels = None - HAS_SGMV = False - - -MIN_SGMV_RANK = 8 -MIN_RANK_CUSTOM = 16 -MAX_RANK_CUSTOM = 128 -SGMV_BLOCK_SIZE = 16 -BGMV_MAX_RANK = 64 - - -def has_sgmv() -> bool: - return HAS_SGMV - - -def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: - """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" - if not has_sgmv(): - return t - - # tensor parallelism will result in effective rank being divided by world_size, - # so we need to scale the min rank to offset that effect - min_rank = MIN_SGMV_RANK * world_size - - # if we're at or below the min rank, pad up to the min rank - # otherwise, pad to the nearest multiple of the block size - current_rank = t.size(dim) - target_rank = ( - min_rank - if current_rank <= min_rank - else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE - ) - if current_rank == target_rank: - return t - - pad_size = target_rank - current_rank - - # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - pad = [0, 0] * t.dim() - pad[(t.dim() - dim - 1) * 2 + 1] = pad_size - pad = tuple(pad) - - return F.pad(t, pad, mode="constant", value=0.0) - - -def use_cutlass_shrink(lora_rank: int) -> bool: - return lora_rank < MIN_RANK_CUSTOM - - -def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: - if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: - return t.transpose(0, 1) - return t - - -# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py -def add_lora_sgmv_cutlass( - y: torch.Tensor, - x: torch.Tensor, - 
wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.Tensor, - s_end: torch.Tensor, - layer_idx: int, - lora_rank: int, -): - """ - Semantics: - y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) - - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H1]`. - wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ - Weight matrix shape: `[num_layers, R, H2]`. - s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. - s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. - layer_idx: Layer index of the weight matrices. - """ - if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: - # Custom SGMV shrink only supports rank 16, 32, 64, 128 - _add_lora_sgmv_cutlass_legacy( - y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank - ) - return - - tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) - tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) - - -def _add_lora_sgmv_cutlass_legacy( - y: torch.Tensor, - x: torch.Tensor, - wa_ptr: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -): - tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) - tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -@lru_cache(maxsize=1) -def get_tmp_tensor(device: torch.device) -> torch.Tensor: - return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) - - -@lru_cache(maxsize=32) -def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: - tmp_size = _kernels.sgmv_cutlass_tmp_size(size) - return torch.empty((tmp_size,), dtype=torch.uint8, device=device) - - -def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: - return torch.empty((size,), dtype=torch.uint8, device=device) - - -def get_tmp_expand_size(size: int) -> int: - return _kernels.sgmv_cutlass_tmp_size(size) - - -def get_tmp_tensors( - nsegments: int, lora_rank: int, device: torch.device -) -> Tuple[torch.Tensor, torch.Tensor]: - use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv() - has_sgmv_available = has_sgmv() - - if use_cutlass: - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - elif has_sgmv_available: - return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device) - else: - tmp = get_tmp_tensor_for_size(nsegments, device) - return tmp, tmp - - -def lora_a_sgmv_cutlass( - x: torch.Tensor, - tmp: torch.Tensor, - wa_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, - lora_rank: int, -) -> torch.Tensor: - v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) - if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: - _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, 
layer_idx) - else: - _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) - return v - - -def lora_b_sgmv_cutlass( - y: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - wb_ptr: torch.Tensor, - s_start: torch.IntTensor, - s_end: torch.IntTensor, - layer_idx: int, -): - _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) - - -""" -Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - -Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - v: Shape: `[B, R]`. Temporary vector. - x: Shape: `[B, H1]`. Input vectors. - wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. - wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. -""" - - -def add_lora_a_bgmv( - v: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) - - -def add_lora_b_bgmv( - y: torch.Tensor, - v: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, -): - _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) - - -def segmented_matmul( - y: torch.Tensor, - x: torch.Tensor, - w: List[torch.Tensor], - b: List[torch.Tensor], - s_start: torch.IntTensor, - s_end: torch.IntTensor, -): - for i in range(len(w)): - if s_end[i] - s_start[i] <= 0: - continue - - xi = x[s_start[i] : s_end[i]] - wi = w[i] - bi = b[i] - y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi) diff --git a/server/uv.lock b/server/uv.lock index b4a95c13..7e6f194a 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -2,13 +2,17 @@ version = 1 revision = 1 requires-python = ">=3.9" resolution-markers = [ - "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.11.*' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.10.*' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.10.*' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + 
"(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.10' and sys_platform != 'linux' and sys_platform != 'win32'", ] @@ -24,8 +28,8 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/85/15/0fab0260ab4069e5224e637d2e400538bb27b0dfc36f17daf68db9770d78/accelerate-1.3.0.tar.gz", hash = "sha256:518631c0adb80bd3d42fb29e7e2dc2256bcd7c786b0ba9119bbaa08611b36d9c", size = 342758 } wheels = [ @@ -194,8 +198,8 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/db/9d/9382259196d7ad7f3550702390081224e673a705e75b5660ee377b592fc0/bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:ba3a720187f518b172ebce4081049c682ae3fd8284947e22499b256ff99a2bc3", size = 69680042 }, @@ -321,8 +325,8 @@ version = "0.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "transformers" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/40/e0/d9529aae2d2425d214e5a50497df4532d3f9e21c8d2023037c701f8a37d3/compressed-tensors-0.9.1.tar.gz", hash = "sha256:3cf5cd637f0186c184dd5bbbbf941356b1225199b49c6a45bf0909d65907f686", size = 63060 } @@ -833,8 +837,8 @@ dependencies = [ { name = "huggingface-hub" }, { name = "packaging" }, { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/26/99/41af9dce502bb1682977fee1bc487a73fa8418cebbce16b8d27733947375/kernels-0.2.1.tar.gz", hash = "sha256:918942332819b28377b9d07070daddecfd8a5e7bab574dd3dc64a209ca6008b2", size = 9395 } @@ -1092,7 +1096,8 @@ name = "networkx" version = "3.2.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.10' and sys_platform != 'linux' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/c4/80/a84676339aaae2f1cfdf9f418701dd634aef9cc76f708ef55c36ff39c3ca/networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6", size = 2073928 } @@ -1105,11 +1110,14 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.11.*' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.10.*' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + 
"python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.10.*' and sys_platform != 'linux' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368 } @@ -1122,7 +1130,8 @@ name = "numpy" version = "2.0.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.10' and sys_platform != 'linux' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/a9/75/10dd1f8116a8b796cb2c737b674e02d02e80454bda953fa7e65d8c12b016/numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78", size = 18902015 } @@ -1178,11 +1187,14 @@ name = "numpy" version = "2.2.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.11.*' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.10.*' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.10.*' and sys_platform != 'linux' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/ec/d0/c12ddfd3a02274be06ffc71f3efc6d0e457b0409c4481596881e748cb264/numpy-2.2.2.tar.gz", hash = "sha256:ed6906f61834d687738d25988ae117683705636936cc605be0bb208b23df4d8f", size = 20233295 } @@ -1245,120 +1257,128 @@ wheels = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.4.5.8" +version = "12.8.3.14" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = 
"sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, + { url = "https://files.pythonhosted.org/packages/82/df/4b01f10069e23c641f116c62fc31e31e8dc361a153175d81561d15c8143b/nvidia_cublas_cu12-12.8.3.14-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44", size = 609620630 }, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" +version = "12.8.57" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, + { url = "https://files.pythonhosted.org/packages/39/6f/3683ecf4e38931971946777d231c2df00dd5c1c4c2c914c42ad8f9f4dca6/nvidia_cuda_cupti_cu12-12.8.57-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950", size = 10237547 }, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" +version = "12.8.61" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, + { url = "https://files.pythonhosted.org/packages/d4/22/32029d4583f7b19cfe75c84399cbcfd23f2aaf41c66fc8db4da460104fff/nvidia_cuda_nvrtc_cu12-12.8.61-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a0fa9c2a21583105550ebd871bd76e2037205d56f33f128e69f6d2a55e0af9ed", size = 88024585 }, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" +version = "12.8.57" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, + { url = "https://files.pythonhosted.org/packages/16/f6/0e1ef31f4753a44084310ba1a7f0abaf977ccd810a604035abb43421c057/nvidia_cuda_runtime_cu12-12.8.57-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be", size = 954762 }, ] [[package]] name = "nvidia-cudnn-cu12" -version = "9.1.0.70" +version = "9.7.1.26" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, + { url = "https://files.pythonhosted.org/packages/25/dc/dc825c4b1c83b538e207e34f48f86063c88deaa35d46c651c7c181364ba2/nvidia_cudnn_cu12-9.7.1.26-py3-none-manylinux_2_27_x86_64.whl", hash = 
"sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07", size = 726851421 }, ] [[package]] name = "nvidia-cufft-cu12" -version = "11.2.1.3" +version = "11.3.3.41" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, + { url = "https://files.pythonhosted.org/packages/ac/26/b53c493c38dccb1f1a42e1a21dc12cba2a77fbe36c652f7726d9ec4aba28/nvidia_cufft_cu12-11.3.3.41-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a", size = 193118795 }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.0.11" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/9c/1f3264d0a84c8a031487fb7f59780fc78fa6f1c97776233956780e3dc3ac/nvidia_cufile_cu12-1.13.0.11-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:483f434c541806936b98366f6d33caef5440572de8ddf38d453213729da3e7d4", size = 1197801 }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.5.147" +version = "10.3.9.55" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, + { url = "https://files.pythonhosted.org/packages/bd/fc/7be5d0082507269bb04ac07cc614c84b78749efb96e8cf4100a8a1178e98/nvidia_curand_cu12-10.3.9.55-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8387d974240c91f6a60b761b83d4b2f9b938b7e0b9617bae0f0dafe4f5c36b86", size = 63618038 }, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.6.1.9" +version = "11.7.2.55" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, + { url = 
"https://files.pythonhosted.org/packages/c2/08/953675873a136d96bb12f93b49ba045d1107bc94d2551c52b12fa6c7dec3/nvidia_cusolver_cu12-11.7.2.55-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b", size = 260373342 }, ] [[package]] name = "nvidia-cusparse-cu12" -version = "12.3.1.170" +version = "12.5.7.53" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, + { url = "https://files.pythonhosted.org/packages/c2/ab/31e8149c66213b846c082a3b41b1365b831f41191f9f40c6ddbc8a7d550e/nvidia_cusparse_cu12-12.5.7.53-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d", size = 292064180 }, ] [[package]] name = "nvidia-cusparselt-cu12" -version = "0.6.2" +version = "0.6.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 }, + { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796 }, ] [[package]] name = "nvidia-nccl-cu12" -version = "2.21.5" +version = "2.26.2" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414 }, + { url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755 }, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.4.127" +version = "12.8.61" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, + { url = "https://files.pythonhosted.org/packages/03/f8/9d85593582bd99b8d7c65634d2304780aefade049b2b94d96e44084be90b/nvidia_nvjitlink_cu12-12.8.61-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17", size = 39243473 }, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.4.127" +version = "12.8.55" source 
= { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, + { url = "https://files.pythonhosted.org/packages/8d/cd/0e8c51b2ae3a58f054f2e7fe91b82d201abfb30167f2431e9bd92d532f42/nvidia_nvtx_cu12-12.8.55-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2dd0780f1a55c21d8e06a743de5bd95653de630decfff40621dbde78cc307102", size = 89896 }, ] [[package]] @@ -1525,8 +1545,8 @@ dependencies = [ { name = "pydantic" }, { name = "referencing" }, { name = "requests" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "tqdm" }, { name = "typing-extensions" }, ] @@ -1649,8 +1669,8 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -2418,7 +2438,8 @@ name = "scipy" version = "1.13.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.10' and sys_platform != 'linux' and sys_platform != 'win32'", ] dependencies = [ @@ -2457,11 +2478,14 @@ name = "scipy" version = "1.15.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.11.*' 
and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux' and sys_platform != 'win32'", - "(python_full_version == '3.10.*' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.10.*' and sys_platform != 'linux' and sys_platform != 'win32'", ] dependencies = [ @@ -2579,14 +2603,14 @@ wheels = [ [[package]] name = "sympy" -version = "1.13.1" +version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mpmath" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040 } +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 }, + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353 }, ] [[package]] @@ -2594,6 +2618,7 @@ name = "text-generation-server" version = "2.0.5.dev0" source = { editable = "." 
} dependencies = [ + { name = "click" }, { name = "einops" }, { name = "grpc-interceptor" }, { name = "grpcio" }, @@ -2652,16 +2677,18 @@ quantize = [ { name = "texttable" }, ] torch = [ - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "torchvision", version = "0.21.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torchvision", version = "0.21.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "torchvision", version = "0.22.0", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torchvision", version = "0.22.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torchvision", version = "0.22.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] [package.metadata] requires-dist = [ { name = "accelerate", marker = "extra == 'accelerate'", specifier = ">=1.2.1,<2" }, { name = "bitsandbytes", marker = "extra == 'bnb'", specifier = ">=0.45.0" }, + { name = "click", specifier = "<8.2.0" }, { name = "compressed-tensors", marker = "extra == 'compressed-tensors'", specifier = ">=0.9.0" }, { name = "datasets", marker = "extra == 'quantize'", specifier = ">=2.21,<3" }, { name = "einops", specifier = ">=0.8.0" }, @@ -2694,10 +2721,10 @@ requires-dist = [ { name = "sentencepiece", specifier = ">=0.2.0" }, { name = "texttable", marker = "extra == 'quantize'", specifier = ">=1.6.7,<2" }, { name = "tokenizers", specifier = ">=0.20.3" }, - { name = "torch", marker = "(sys_platform == 'linux' and extra == 'torch') or (sys_platform == 'win32' and extra == 'torch')", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cu124" }, - { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'torch'", specifier = "==2.6.0" }, - { name = "torchvision", marker = "(sys_platform == 'linux' and extra == 'torch') or (sys_platform == 'win32' and extra == 'torch')", specifier = "==0.21.0", index = "https://download.pytorch.org/whl/cu124" }, - { name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'torch'", specifier = "==0.21.0" }, + { name = "torch", marker = "(sys_platform == 'linux' and extra == 'torch') or (sys_platform == 'win32' and extra == 'torch')", specifier = "==2.7.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'torch'", specifier = "==2.7.0" }, + { name = "torchvision", marker = "(sys_platform == 'linux' and extra == 'torch') 
or (sys_platform == 'win32' and extra == 'torch')", specifier = "==0.22.0", index = "https://download.pytorch.org/whl/cu128" }, + { name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32' and extra == 'torch'", specifier = "==0.22.0" }, { name = "transformers", specifier = ">=4.51.0" }, { name = "typer", specifier = ">=0.15.1" }, ] @@ -2778,7 +2805,7 @@ wheels = [ [[package]] name = "torch" -version = "2.6.0" +version = "2.7.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'", @@ -2797,22 +2824,27 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/16/ea1b7842413a7b8a5aaa5e99e8eaf3da3183cc3ab345ad025a07ff636301/torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628", size = 66520221 }, - { url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713 }, - { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538 }, - { url = "https://files.pythonhosted.org/packages/88/8b/d60c0491ab63634763be1537ad488694d316ddc4a20eaadd639cedc53971/torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2", size = 66536783 }, - { url = "https://files.pythonhosted.org/packages/b3/17/41f681b87290a1d2f1394f943e470f8b0b3c2987b7df8dc078d8831fce5b/torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c", size = 66520446 }, + { url = "https://files.pythonhosted.org/packages/dc/0b/b2b83f30b8e84a51bf4f96aa3f5f65fdf7c31c591cc519310942339977e2/torch-2.7.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:34e0168ed6de99121612d72224e59b2a58a83dae64999990eada7260c5dd582d", size = 68559462 }, + { url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377 }, + { url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294 }, + { url = "https://files.pythonhosted.org/packages/28/fd/74ba6fde80e2b9eef4237fe668ffae302c76f0e4221759949a632ca13afa/torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:edad98dddd82220465b106506bb91ee5ce32bd075cddbcf2b443dfaa2cbd83bf", size = 68856166 }, + { url = "https://files.pythonhosted.org/packages/90/48/7e6477cf40d48cc0a61fa0d41ee9582b9a316b12772fcac17bc1a40178e7/torch-2.7.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:27f5007bdf45f7bb7af7f11d1828d5c2487e030690afb3d89a651fd7036a390e", size = 68575074 }, + { url = 
"https://files.pythonhosted.org/packages/85/11/571d6363d1aaee3033af46b40798a0238b24522e9b291b676446943cc8a9/torch-2.7.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:ccd7509141713997861b7a947ef0a717143cd7e9240addd168f38ba8fd23fd56", size = 68560465 }, ] [[package]] name = "torch" -version = "2.6.0+cu124" -source = { registry = "https://download.pytorch.org/whl/cu124" } +version = "2.7.0+cu128" +source = { registry = "https://download.pytorch.org/whl/cu128" } resolution-markers = [ - "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", - "(python_full_version == '3.11.*' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", - "(python_full_version == '3.10.*' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", - "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -2826,6 +2858,7 @@ dependencies = [ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -2835,26 +2868,58 @@ dependencies = [ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')" }, { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] 
wheels = [ - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-linux_x86_64.whl", hash = "sha256:7f2ba7f7c0459320a521696f6b5bccc187f59890b23c9dfb6c49b0b87c6bfc97" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-win_amd64.whl", hash = "sha256:7cc45c5b39d74875cfafe908b7f55c544147cc16b01e795feb2fe766583efe78" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl", hash = "sha256:d4c3e9a8d31a7c0fcbb9da17c31a1917e1fac26c566a4cfbd8c9568ad7cade79" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-win_amd64.whl", hash = "sha256:6a1fb2714e9323f11edb6e8abf7aad5f79e45ad25c081cde87681a18d99c29eb" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-linux_x86_64.whl", hash = "sha256:a393b506844035c0dac2f30ea8478c343b8e95a429f06f3b3cadfc7f53adb597" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp312-cp312-win_amd64.whl", hash = "sha256:3313061c1fec4c7310cf47944e84513dcd27b6173b72a349bb7ca68d0ee6e9c0" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313-linux_x86_64.whl", hash = "sha256:0f3bc53c988ce9568cd876a2a5316761e84a8704135ec8068f5f81b4417979cb" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313-win_amd64.whl", hash = "sha256:519330eef09534acad8110b6f423d2fe58c1d8e9ada999ed077a637a0021f908" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp313-cp313t-linux_x86_64.whl", hash = "sha256:35cba404c0d742406cdcba1609085874bc60facdfbc50e910c47a92405fef44c" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp39-cp39-linux_x86_64.whl", hash = "sha256:e661267cd0242462ab100bdd67f651988aa9f67eb31609d6909afcac891df612" }, - { url = "https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp39-cp39-win_amd64.whl", hash = "sha256:c2eb62b99161d87be486c88fd82441274cc892bce8c48dbc28c055cb147732ce" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ac1849553ee673dfafb44c610c60cb60a2890f0e117f43599a526cf777eb8b8c" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:c52c4b869742f00b12cb34521d1381be6119fa46244791704b00cc4a3cb06850" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c4bbc0b4be60319ba1cefc90be9557b317f0b3c261eeceb96ca6e0343eec56bf" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:bf88f647d76d79da9556ca55df49e45aff1d66c12797886364343179dd09a36c" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7c0f08d1c44a02abad389373dddfce75904b969a410be2f4e5109483dd3dc0ce" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:1704e5dd66c9221e4e8b6ae2d80cbf54e129571e643f5fa9ca78cc6d2096403a" }, + { url = 
"https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:d2f69f909da5dc52113ec66a851d62079f3d52c83184cf64beebdf12ca2f705c" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:58c749f52ddc9098155c77d6c74153bb13d8978fd6e1063b5d7b41d4644f5af5" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:78e13c26c38ae92d6841cf9ce760d7e9d52bca3e3183de371812e84274b054dc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:3559e98be824c2b12ab807319cd61c6174d73a524c9961317de8e8a44133c5c5" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp39-cp39-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:f446f97b20cb070747b103fb640df941b88cb68c8d3b01538287d05d56a7e874" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.7.0%2Bcu128-cp39-cp39-win_amd64.whl", hash = "sha256:8614a167d6a163273fb130f586802f3243479862b53ee2843941c10cc5761da6" }, ] [[package]] name = "torchvision" -version = "0.21.0" +version = "0.22.0" +source = { registry = "https://download.pytorch.org/whl/cu128" } +resolution-markers = [ + "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.10.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' and platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pillow", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:566224d7b4f00bc6366bed1d62f834ca80f8e57fe41e10e4a5636bfa3ffb984e" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6be714bcdd8849549571f6acfaa2dfa9e00676f042bda517432745fb116f7904" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6e9752b48c1cdd7f6428bcd30c3d198b30ecea348d16afb651f95035e5252506" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:e4d4d5a14225875d9bf8c5221d43d8be97786adc498659493799bdeff52c54cf" }, + { url = 
"https://download.pytorch.org/whl/cu128/torchvision-0.22.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e50ff5bbae11f57fd3af8e6f2185c136f32e8b94324613428228dd27eba6a4f6" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:7a398fad02f4ac6b7d18bea9a08dc14163ffc5a368618f29ceb0e53dfa91f69e" }, +] + +[[package]] +name = "torchvision" +version = "0.22.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12' and sys_platform != 'linux' and sys_platform != 'win32'", @@ -2866,43 +2931,46 @@ dependencies = [ { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' and sys_platform != 'linux' and sys_platform != 'win32'" }, { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' and sys_platform != 'linux' and sys_platform != 'win32'" }, { name = "pillow", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, + { name = "torch", version = "2.7.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/0d/143bd264876fad17c82096b6c2d433f1ac9b29cdc69ee45023096976ee3d/torchvision-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:044ea420b8c6c3162a234cada8e2025b9076fa82504758cd11ec5d0f8cd9fa37", size = 1784140 }, - { url = "https://files.pythonhosted.org/packages/29/88/00c69db213ee2443ada8886ec60789b227e06bb869d85ee324578221a7f7/torchvision-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:110d115333524d60e9e474d53c7d20f096dbd8a080232f88dddb90566f90064c", size = 1784141 }, - { url = "https://files.pythonhosted.org/packages/6e/1b/28f527b22d5e8800184d0bc847f801ae92c7573a8c15979d92b7091c0751/torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:97a5814a93c793aaf0179cfc7f916024f4b63218929aee977b645633d074a49f", size = 1784140 }, - { url = "https://files.pythonhosted.org/packages/f9/56/47d456b61c3bbce7bed4af3925c83d405bb87468e659fd3cf3d9840c3b51/torchvision-0.21.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:659b76c86757cb2ee4ca2db245e0740cfc3081fef46f0f1064d11adb4a8cee31", size = 1784141 }, - { url = "https://files.pythonhosted.org/packages/49/d5/d18c5d89cbe32015b033f1fa06918c7cdd5c0af0c03e55d72a3cc2d768f8/torchvision-0.21.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5c22caeaae8b3c36d93459f1a5294e6f43306cff856ed243189a229331a404b4", size = 1784154 }, + { url = "https://files.pythonhosted.org/packages/eb/03/a514766f068b088180f273913e539d08e830be3ae46ef8577ea62584a27c/torchvision-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72256f1d7ff510b16c9fb4dd488584d0693f40c792f286a9620674438a81ccca", size = 1947829 }, + { url = "https://files.pythonhosted.org/packages/b1/43/28bc858b022f6337326d75f4027d2073aad5432328f01ee1236d847f1b82/torchvision-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:191ea28321fc262d8aa1a7fe79c41ff2848864bf382f9f6ea45c41dde8313792", size = 1947828 }, + { url = "https://files.pythonhosted.org/packages/cb/ea/887d1d61cf4431a46280972de665f350af1898ce5006cd046326e5d0a2f2/torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:31c3165418fe21c3d81fe3459e51077c2f948801b8933ed18169f54652796a0f", size = 1947826 }, + { url = "https://files.pythonhosted.org/packages/e1/2a/9b34685599dcb341d12fc2730055155623db7a619d2415a8d31f17050952/torchvision-0.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ece17995857dd328485c9c027c0b20ffc52db232e30c84ff6c95ab77201112c5", size = 1947823 }, + { url = "https://files.pythonhosted.org/packages/6f/a7/f43e9c8d13118b4ffbaebea664c9338ab20fa115a908125afd2238ff16e7/torchvision-0.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cdc96daa4658b47ce9384154c86ed1e70cba9d972a19f5de6e33f8f94a626790", size = 2137621 }, + { url = "https://files.pythonhosted.org/packages/3a/6e/eb662050a22a75a85b3b5e5f33dddfdc487c10ffcd20b82a8c2a4a6cd56c/torchvision-0.22.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2ef38a397f1b9cf62846fb20659cb99101f9d361de8c45d79284ee45c6f40d50", size = 1947880 }, ] [[package]] name = "torchvision" -version = "0.21.0+cu124" -source = { registry = "https://download.pytorch.org/whl/cu124" } +version = "0.22.0+cu128" +source = { registry = "https://download.pytorch.org/whl/cu128" } resolution-markers = [ - "(python_full_version >= '3.12' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", - "(python_full_version == '3.11.*' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", - "(python_full_version == '3.10.*' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", - "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", + "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'win32')", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'win32')", + "(python_full_version == '3.10.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.10.*' and sys_platform == 'win32')", + "(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')", ] dependencies = [ - { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')" }, - { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform == 'linux') or (python_full_version >= '3.10' and sys_platform == 'win32')" }, - { name = "pillow", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'win32')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.10' and 
sys_platform == 'win32')" }, + { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, + { name = "torch", version = "2.7.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp310-cp310-linux_x86_64.whl", hash = "sha256:3d3e74018eaa7837c73e3764dad3b7792b7544401c25a42977e9744303731bd3" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp310-cp310-win_amd64.whl", hash = "sha256:0c6aefb70ab2b312065240c804e459ac7b0e449867afd469b38d2fd47f9391a7" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-linux_x86_64.whl", hash = "sha256:137376805aca5ba57bd2c7a3ecb8569df961dbe82b128aac9b3b0a7125ef9385" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-win_amd64.whl", hash = "sha256:000a013584ad2304ab30496318145f284ac364622addb5ee3a5abd2769ba146f" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp312-cp312-linux_x86_64.whl", hash = "sha256:efb53ea0af7bf09b7b53e2a18b9be6d245f7d46a90b51d5cf97f37e9b929a991" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp312-cp312-win_amd64.whl", hash = "sha256:ec63c2ee792757492da40590e34b14f2fceda29050558c215f0c1f3b08149c0f" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp313-cp313-linux_x86_64.whl", hash = "sha256:4b70acf3b4b96a0ceb1374116626c9bef9e8be016b57b1284e482260ca1896d6" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp313-cp313-win_amd64.whl", hash = "sha256:8fcf55321b206de70ff8e01c884fa42e57a60b1cb749341b96e0f22c8a7c9ec7" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp39-cp39-linux_x86_64.whl", hash = "sha256:6afb21a22f5497e08ea4dbd4544472330d8249bf09dafd239302552cad6906b2" }, - { url = "https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp39-cp39-win_amd64.whl", hash = "sha256:579b6a7fffc34a860c57a7131221ef125831f5961431f8da15760ab1ef752d44" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:59df5a550113a80ce523047066eaaedb168c69482da88c3ab246716ab45ba092" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp310-cp310-win_amd64.whl", hash = "sha256:cdd90b768b01b0d638cb06a6c211b550b275c0c207b5210b7cbb5cea8dde11db" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f3ac527d58b4c2043eb8d9e29fc56cd1751f36f2aaa6dc75e34ec54c951bcb9c" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:f5dae1307c34813425c0b753530c035e1cc72af0bded395d1ba64dcb2872889f" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:06c101f40e1ff94869be14487c91fd5352e376f202fdeafb8f53c58cee2fbeb5" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:a87393c86649b7e56b4bf859fe95922ee6ec1c1f3b430246fb1a5b51f8aee37a" }, + { url = 
"https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:ee4fa6d4052d9ae25c1233289947fbfa4b88d23710254ab1772b108c1fc5fb4d" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:17d50ffb1df6320da16b85395f1078bf369250ea144f3bb405088aca3d5f030f" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:209c29d78cf2003cf4e22c9b651790f57171334998ee3125594d130526aeaa50" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:03b454b867f7a0aa9861a463042141448c4f15bec784def19eed39a57fac217b" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:c92a353ff82db3312644b5b26d410b586b72969b535948d584c247569f75605c" }, + { url = "https://download.pytorch.org/whl/cu128/torchvision-0.22.0%2Bcu128-cp39-cp39-win_amd64.whl", hash = "sha256:90a0dacad36b1ea8de912af8583cbe780b4a1bdf9cb85870fe548fdec212ab31" }, ] [[package]] @@ -2941,19 +3009,23 @@ wheels = [ [[package]] name = "triton" -version = "3.2.0" +version = "3.3.0" source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] wheels = [ - { url = "https://files.pythonhosted.org/packages/01/65/3ffa90e158a2c82f0716eee8d26a725d241549b7d7aaf7e4f44ac03ebd89/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62", size = 253090354 }, - { url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636 }, - { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365 }, - { url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 }, - { url = "https://files.pythonhosted.org/packages/bc/74/9f12bdedeb110242d8bb1bd621f6605e753ee0cbf73cf7f3a62b8173f190/triton-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30ceed0eff2c4a73b14eb63e052992f44bbdf175f3fad21e1ac8097a772de7ee", size = 253057866 }, + { url = "https://files.pythonhosted.org/packages/76/04/d54d3a6d077c646624dc9461b0059e23fd5d30e0dbe67471e3654aec81f9/triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fad99beafc860501d7fcc1fb7045d9496cbe2c882b1674640304949165a916e7", size = 156441993 }, + { url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461 }, + { url = 
"https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509 }, + { url = "https://files.pythonhosted.org/packages/7d/74/4bf2702b65e93accaa20397b74da46fb7a0356452c1bb94dbabaf0582930/triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1", size = 156516468 }, + { url = "https://files.pythonhosted.org/packages/0a/93/f28a696fa750b9b608baa236f8225dd3290e5aff27433b06143adc025961/triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742", size = 156580729 }, + { url = "https://files.pythonhosted.org/packages/f0/9c/315d25590fc309e2d28bb67953526238fac5d54548a16ceca992c76441bc/triton-3.3.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f41403bfa0cbb3e24fd958ca7fee04e9681e55e539296db9aca30c42acae693", size = 156439372 }, ] [[package]] name = "typer" -version = "0.15.1" +version = "0.15.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -2961,9 +3033,9 @@ dependencies = [ { name = "shellingham" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/dca7b219718afd37a0068f4f2530a727c2b74a8b6e8e0c0080a4c0de4fcd/typer-0.15.1.tar.gz", hash = "sha256:a0588c0a7fa68a1978a069818657778f86abe6ff5ea6abf472f940a08bfe4f0a", size = 99789 } +sdist = { url = "https://files.pythonhosted.org/packages/98/1a/5f36851f439884bcfe8539f6a20ff7516e7b60f319bbaf69a90dc35cc2eb/typer-0.15.3.tar.gz", hash = "sha256:818873625d0569653438316567861899f7e9972f2e6e0c16dab608345ced713c", size = 101641 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/cc/0a838ba5ca64dc832aa43f727bd586309846b0ffb2ce52422543e6075e8a/typer-0.15.1-py3-none-any.whl", hash = "sha256:7994fb7b8155b64d3402518560648446072864beefd44aa2dc36972a5972e847", size = 44908 }, + { url = "https://files.pythonhosted.org/packages/48/20/9d953de6f4367163d23ec823200eb3ecb0050a2609691e512c8b95827a9b/typer-0.15.3-py3-none-any.whl", hash = "sha256:c86a65ad77ca531f03de08d1b9cb67cd09ad02ddddf4b34745b5008f43b239bd", size = 45253 }, ] [[package]]