Merge branch 'main' into gaudi_backend_pa

This commit is contained in:
Wang, Yi A 2025-03-19 18:15:08 -07:00
commit d5b78ba16f
102 changed files with 7043 additions and 12285 deletions

View File

@ -124,6 +124,15 @@ jobs:
export extra_pytest="--neuron"
export target=""
;;
gaudi)
export dockerfile="Dockerfile_gaudi"
export label_extension="-gaudi"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="ubuntu-latest"
export platform=""
export extra_pytest=""
export target=""
esac
echo $dockerfile
echo "Dockerfile=${dockerfile}"
@ -224,7 +233,12 @@ jobs:
- name: Final
id: final
run: |
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
else
echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
fi
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"

View File

@ -21,6 +21,7 @@ on:
- "Dockerfile_amd"
- "Dockerfile_intel"
- "Dockerfile.neuron"
- "Dockerfile_gaudi"
branches:
- "main"
workflow_dispatch:
@ -38,7 +39,7 @@ jobs:
# fail-fast is true by default
fail-fast: false
matrix:
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron"]
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write

53
.github/workflows/nix_build.yaml vendored Normal file
View File

@ -0,0 +1,53 @@
name: "Nix Build Docker image"
on:
pull_request:
push:
branches:
- 'main'
tags:
- 'v*'
concurrency:
group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build_nix_image:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix build .#dockerImage
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Login to internal Container Registry
# if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Push to docker
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
else
export TAG=nix-${{ github.ref_name }}
fi
export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"

View File

@ -46,7 +46,7 @@ jobs:
- name: Download locked kernels
run: |
source ./.venv/bin/activate
hf-kernels download server
kernels download server
- name: Run server tests
run: |
source ./.venv/bin/activate

1
.gitignore vendored
View File

@ -28,3 +28,4 @@ server/fbgemmm
hl-smi_log*.txt
.graph_dumps
out
hqt_output

16
Cargo.lock generated
View File

@ -4617,7 +4617,7 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"async-trait",
"clap 4.5.30",
@ -4638,7 +4638,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"average",
"clap 4.5.30",
@ -4658,7 +4658,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -4676,7 +4676,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"clap 4.5.30",
"ctrlc",
@ -4697,7 +4697,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"anyhow",
"async-stream",
@ -4749,7 +4749,7 @@ dependencies = [
[[package]]
name = "text-generation-router-llamacpp"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"async-trait",
"bindgen 0.71.1",
@ -4767,7 +4767,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v2"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"async-stream",
"async-trait",
@ -4816,7 +4816,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
version = "3.1.2-dev0"
version = "3.2.1-dev0"
dependencies = [
"async-stream",
"async-trait",

View File

@ -21,7 +21,7 @@ default-members = [
resolver = "2"
[workspace.package]
version = "3.1.2-dev0"
version = "3.2.1-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.1", features = ["tokio"] }
hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }

View File

@ -183,12 +183,12 @@ COPY server server
COPY server/Makefile server/Makefile
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project --active && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --no-install-project --active && \
make gen-server-raw && \
hf-kernels download .
kernels download .
RUN cd server && \
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --active --python=${PYTHON_VERSION} && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --active --python=${PYTHON_VERSION} && \
uv pip install nvidia-nccl-cu12==2.25.1 && \
pwd && \
text-generation-server --help

View File

@ -5,7 +5,7 @@ RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.0.28.tar.gz /optimum-neuron/sources.tar.gz
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.18.20.0 \
aws-neuronx-collectives=2.22.33.0-d2128d1aa \
aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 \
aws-neuronx-tools=2.19.0.0 \
aws-neuronx-dkms=2.19.64.0 \
aws-neuronx-collectives=2.23.135.0-3e70920f2 \
aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
aws-neuronx-tools=2.20.204.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@ -120,16 +120,16 @@ ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
# Install manually torch CPU version to avoid pulling CUDA
RUN pip3 install \
torch==2.1.2 \
torchvision==0.16.2 \
torch==2.5.1 \
torchvision==0.20.1 \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.15.143.0 \
torch-neuronx==2.1.2.2.3.2 \
transformers-neuronx==0.12.313 \
neuronx-distributed==0.9.0 \
libneuronxla==2.0.5347.0 \
neuronx-cc==2.16.372.0 \
torch-neuronx==2.5.1.2.4.0 \
transformers-neuronx==0.13.322 \
neuronx-distributed==0.10.1 \
libneuronxla==2.1.681.0 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages

View File

@ -1,6 +1,6 @@
# Those arguments are required to build the image
ARG HABANA_VERSION
ARG PYTORCH_VERSION
ARG HABANA_VERSION=1.20.0
ARG PYTORCH_VERSION=2.6.0
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
@ -92,7 +92,6 @@ RUN cd server && \
make gen-server && \
pip install --no-deps -r requirements.txt && \
bash ./dill-0.3.8-patch.sh && \
pip install outlines~=0.0.34 && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir

View File

@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
USER root
@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -96,13 +96,11 @@ ENV HF_HOME=/data \
WORKDIR /usr/src
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
WORKDIR /usr/src
RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
# Install server
COPY proto proto
@ -114,15 +112,14 @@ RUN cd server && \
pip install -U pip uv && \
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets
#ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router

View File

@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
clean:
rm -rf target aml
preview_doc:
doc-builder preview text-generation-inference docs/source --not_python_module

View File

@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model
```
And then you can make requests like
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. To run the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`; please note that CPU is not the intended platform for this project, so performance might be subpar.
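For instance, a CPU-only variant of the quickstart command above (reusing the `$model` and `$volume` variables defined earlier; shown purely as an illustration) would be:
```bash
docker run --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:3.2.1 \
    --model-id $model --disable-custom-kernels
```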
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model
```
### A note on Shared Memory (shm)

View File

@ -2,8 +2,8 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := "${mkfile_dir}/../.."
HABANA_VERSION := 1.19.0
PYTORCH_VERSION := 2.5.1
HABANA_VERSION := 1.20.0
PYTORCH_VERSION := 2.6.0
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

View File

@ -0,0 +1,283 @@
# Examples of Docker Commands for Gaudi Backend
This page lists example `docker run` commands for some of the most popular models.
> **Note:** The parameters are chosen to maximize performance on Gaudi2 hardware; adjust them for your own hardware. For example, on Gaudi3 you may want to increase the batch size.
## Default Precision (BF16)
### Llama3.1-8B on 1 card (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
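Once the container is up, you can sanity-check it with the same `/generate` request used in the Gaudi backend documentation (adjust the host and port if you changed the `-p` mapping):
```bash
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
    -H 'Content-Type: application/json'
```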
### Llama3.1-70B on 8 cards (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llama2-7B on 1 card (BF16)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama2-70B on 8 cards (BF16)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llava-v1.6-Mistral-7B on 1 card (BF16)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## FP8 Precision
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the model's statistics before you can run it in FP8 precision.
### Llama3.1-8B on 1 card (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama3.1-70B on 8 cards (FP8)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llama2-7B on 1 card (FP8)
```bash
model=meta-llama/Llama-2-7b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e PREFILL_BATCH_BUCKET_SIZE=2 \
-e BATCH_BUCKET_SIZE=32 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama2-70B on 8 cards (FP8)
```bash
model=meta-llama/Llama-2-70b-chat-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llava-v1.6-Mistral-7B on 1 card (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
### Llava-v1.6-Mistral-7B on 8 cards (FP8)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```

View File

@ -22,7 +22,7 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
peft = "^0.10"
optimum-habana = "1.15.0"
optimum-habana = "1.16.0"
transformers = "4.45.2"
numpy = "1.26.4"
accelerate = "0.33.0"

View File

@ -46,7 +46,7 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
@ -87,3 +87,18 @@ wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0; python_version >= "3.9" and python_version < "3.13"
pydantic==2.10.6; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13"
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -20,8 +20,9 @@ from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
@ -30,9 +31,6 @@ from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import
)
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
# from text_generation_server.models.custom_modeling.mllama import (
# MllamaForConditionalGeneration,
# )
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
@ -329,6 +327,7 @@ __GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)
@ -849,6 +848,8 @@ def get_model(
trust_remote_code=trust_remote_code,
)
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
@ -871,6 +872,17 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "mllama":
return VlmCausalLM(
model_class=MllamaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=None,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,

View File

@ -704,6 +704,9 @@ class CausalLM(Model):
htorch.core.hpu_set_env()
if world_size > 1:
os.environ.setdefault(
"DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
)
model = self.get_deepspeed_model(model_id, dtype, revision)
model = hq_env.prepare_model_for_quantization(model)
else:

View File

@ -14,25 +14,18 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
from typing import List, Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
import numpy as np
from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
@ -40,7 +33,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
@ -48,7 +41,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (height, width).
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
@ -57,100 +50,53 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
"""
Unpads a PyTorch tensor of a padded and resized image.
Calculate the number of patches after the preprocessing for images of any resolution.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
The size of the input image in the format (height, width). ?
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
`torch.Tensor`: The unpadded image tensor.
int: the number of patches
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
# ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise the calculation will be wrong
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type {type(image_size)} with value {image_size}"
)
image_size = image_size.tolist()
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
input_ids: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
@ -165,126 +111,315 @@ class LlavaNextForConditionalGeneration(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = True,
flash_attention_recompute: Optional[bool] = True,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
if token_idx is not None:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
image_features = self.vision_tower(pixel_values)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
logits = outputs[0]
if not return_dict:
output = (logits,) + outputs[1:]
return output
return outputs
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
and are of shape `(num_patches, image_length, embed_dim)`).
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch]
for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
)
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_features.hidden_states[vision_feature_layer]
else:
hs_pool = [
image_features.hidden_states[layer_idx]
for layer_idx in vision_feature_layer
]
selected_image_feature = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
The only differences are:
- add new args token_idx
- add the process of merging images into inputs_embeds
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_sizes=image_sizes,
attention_mask=attention_mask,
**kwargs,
)
else:
use_flash_attention = kwargs.get("use_flash_attention", True)
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if (
past_key_values is None
and pixel_values is not None
and input_ids.shape[1] != 1
):
vision_feature_select_strategy = kwargs.get(
"vision_feature_select_strategy", None
)
vision_feature_layer = kwargs.get("vision_feature_layer", None)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
image_features = self.multi_modal_projector(selected_image_feature)
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx].tolist(),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.cat(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
inputs_embeds, image_features, input_ids
)
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None:
seq_len = input_ids.shape[1]
pad_len = seq_len - token_idx.item()
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(
first_layer_past_key_value.float().sum(-2) == 0
)
# Get the target length
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = extended_attention_mask
attention_mask[:, -pad_len:] = 0
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = (
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
return model_inputs

View File

@ -25,6 +25,7 @@ image:
--ulimit nofile=100000:100000 \
--build-arg VERSION=$(VERSION) \
-t text-generation-inference:$(VERSION)-neuron ${root_dir}
docker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron
install_server:
make -C ${mkfile_dir}/server install VERSION:=${VERSION}

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "3.1.2-dev0"
"version": "3.2.1-dev0"
},
"paths": {
"/": {
@ -2148,9 +2148,6 @@
},
"StreamOptions": {
"type": "object",
"required": [
"include_usage"
],
"properties": {
"include_usage": {
"type": "boolean",

View File

@ -52,6 +52,8 @@
- sections:
- local: backends/neuron
title: Neuron
- local: backends/gaudi
title: Gaudi
- local: backends/trtllm
title: TensorRT-LLM
- local: backends/llamacpp

View File

@ -0,0 +1,317 @@
# Gaudi Backend for Text Generation Inference
## Overview
Text Generation Inference (TGI) has been optimized to run on Gaudi hardware through a dedicated Gaudi backend.
## Supported Hardware
- **Gaudi1**: Available on [AWS EC2 DL1 instances](https://aws.amazon.com/ec2/instance-types/dl1/)
- **Gaudi2**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
- **Gaudi3**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
## Tutorial: Getting Started with TGI on Gaudi
### Basic Usage
The easiest way to run TGI on Gaudi is to use the official Docker image:
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
hf_token=YOUR_HF_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
--model-id $model
```
Once you see the `Connected` log line, the server is ready to accept requests:
> 2024-05-22T19:31:48.302239Z INFO text_generation_router: router/src/main.rs:378: Connected
You can find or create your access token (`YOUR_HF_ACCESS_TOKEN`) at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). A token is required to access gated models such as Llama 3.1.
### Making Your First Request
You can send a request from a separate terminal:
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
-H 'Content-Type: application/json'
```
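The server also exposes the OpenAI-compatible Messages API on the same port (the `/v1/chat/completions` route referenced in the main README), so chat-formatted requests work as well. A minimal sketch against the server started above:
```bash
curl 127.0.0.1:8080/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 32,
        "stream": false
    }'
```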
## How-to Guides
### How to Run Specific Models
The following models have been validated on Gaudi2:
| Model | Model ID | BF16 | | FP8 | |
|-----------------------|----------------------------------------|-------------|------------|-------------|------------|
| | | Single Card | Multi-Card | Single Card | Multi-Card |
| Llama2-7B | meta-llama/Llama-2-7b-chat-hf | ✔ | ✔ | ✔ | ✔ |
| Llama2-70B | meta-llama/Llama-2-70b-chat-hf | | ✔ | | ✔ |
| Llama3-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
| Llama3-70B | meta-llama/Meta-Llama-3-70B-Instruct | | ✔ | | ✔ |
| Llama3.1-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
| Llama3.1-70B | meta-llama/Meta-Llama-3.1-70B-Instruct | | ✔ | | ✔ |
| CodeLlama-13B | codellama/CodeLlama-13b-hf | ✔ | ✔ | ✔ | ✔ |
| Mixtral-8x7B | mistralai/Mixtral-8x7B-Instruct-v0.1 | ✔ | ✔ | ✔ | ✔ |
| Mistral-7B | mistralai/Mistral-7B-Instruct-v0.3 | ✔ | ✔ | ✔ | ✔ |
| Falcon-180B | tiiuae/falcon-180B-chat | | ✔ | | ✔ |
| Qwen2-72B | Qwen/Qwen2-72B-Instruct | | ✔ | | ✔ |
| Starcoder2-3b | bigcode/starcoder2-3b | ✔ | ✔ | ✔ | |
| Starcoder2-15b | bigcode/starcoder2-15b | ✔ | ✔ | ✔ | |
| Starcoder | bigcode/starcoder | ✔ | ✔ | ✔ | ✔ |
| Gemma-7b | google/gemma-7b-it | ✔ | ✔ | ✔ | ✔ |
| Llava-v1.6-Mistral-7B | llava-hf/llava-v1.6-mistral-7b-hf | ✔ | ✔ | ✔ | ✔ |
To run any of these models:
```bash
model=MODEL_ID_THAT_YOU_WANT_TO_RUN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
hf_token=YOUR_ACCESS_TOKEN
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
--model-id $model
<text-generation-inference-launcher-arguments>
```
For the full list of service parameters, refer to the [launcher-arguments page](https://huggingface.co/docs/text-generation-inference/reference/launcher).
The validated docker commands can be found in the [examples/docker_commands folder](https://github.com/huggingface/text-generation-inference/tree/main/backends/gaudi/examples/docker_commands).
> Note: `--runtime=habana --cap-add=sys_nice --ipc=host` is required for Docker to access the Gaudi hardware (more details [here](https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html)).
### How to Enable Multi-Card Inference (Sharding)
TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards.
For example, on a machine with 8 Gaudi cards, you can run:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
tgi-gaudi \
--model-id $model --sharded true --num-shard 8
```
<Tip>
We recommend always using sharding when running on a multi-card machine.
</Tip>
### How to Use Different Precision Formats
#### BF16 Precision (Default)
By default, all models run with BF16 precision on Gaudi hardware.
#### FP8 Precision
TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html).
To run FP8 inference:
1. Measure statistics using the [Optimum Habana measurement script](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8)
2. Run the model in TGI with the `QUANT_CONFIG` setting, e.g. `-e QUANT_CONFIG=./quantization_config/maxabs_quant.json`.
The following command example for FP8 inference assumes that the measurement in step 1 has already been done.
Example for Llama3.1-70B on 8 cards with FP8 precision:
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-v $PWD/quantization_config:/usr/src/quantization_config \
-v $PWD/hqt_output:/usr/src/hqt_output \
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
-e HF_TOKEN=$hf_token \
-e MAX_TOTAL_TOKENS=2048 \
-e BATCH_BUCKET_SIZE=256 \
-e PREFILL_BATCH_BUCKET_SIZE=4 \
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### How to Run Vision-Language Models (VLMs)
Gaudi supports VLM inference.
Example for Llava-v1.6-Mistral-7B on 1 card:
Start the TGI server via the following command:
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e PREFILL_BATCH_BUCKET_SIZE=1 \
-e BATCH_BUCKET_SIZE=1 \
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
You can then send a request to the server via the following command:
```bash
curl -N 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\n\n","parameters":{"max_new_tokens":32}}' \
-H 'Content-Type: application/json'
```
> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for 2000 input tokens. For example, an image of size 512x512 is represented by 2800 tokens. Thus, `max-input-tokens` must be larger than the number of tokens associated with the image; otherwise, the image may be truncated. We set `BASE_IMAGE_TOKENS=2048` as the default image token value, which is also the minimum value of `max-input-tokens`. You can override the environment variable `BASE_IMAGE_TOKENS` to change this value. The warmup will generate graphs with input lengths from `BASE_IMAGE_TOKENS` to `max-input-tokens`. For Llava-v1.6-Mistral-7B, the value of `max-batch-prefill-tokens` is 16384, which is calculated as `prefill_batch_size * max-input-tokens` (equivalently, `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`).
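As a quick sanity check, the relationship between these flags can be verified with simple shell arithmetic. The values below are taken from the example command above and are purely illustrative:
```bash
# Sizing sketch for the Llava-v1.6-Mistral-7B example above (illustrative values only)
max_input_tokens=4096        # must exceed the image token budget (BASE_IMAGE_TOKENS=2048 by default)
prefill_batch_size=4         # expected maximum prefill batch size (matches --max-batch-size 4 above)
echo $(( prefill_batch_size * max_input_tokens ))   # 16384 -> value for --max-batch-prefill-tokens
```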
### How to Benchmark Performance
We recommend using the [inference-benchmarker tool](https://github.com/huggingface/inference-benchmarker) to benchmark performance on Gaudi hardware.
This benchmarking tool simulates user requests and measures the model's performance under realistic scenarios.
To run it on the same machine, you can do the following:
```bash
MODEL=meta-llama/Llama-3.1-8B-Instruct
HF_TOKEN=<your HF READ token>
# run a benchmark to evaluate the performance of the model for chat use case
# we mount results to the current directory
docker run \
--rm \
-it \
--net host \
-v $(pwd):/opt/inference-benchmarker/results \
-e "HF_TOKEN=$HF_TOKEN" \
ghcr.io/huggingface/inference-benchmarker:latest \
inference-benchmarker \
--tokenizer-name "$MODEL" \
--url http://localhost:8080 \
--profile chat
```
Please refer to the [inference-benchmarker README](https://github.com/huggingface/inference-benchmarker) for more details.
### How to Profile Performance
To collect a performance profile, set the following environment variables:
| Name | Value(s) | Default | Description |
|--------------------| :--------- | :--------------- | :------------------------------------------------------- |
| PROF_WAITSTEP | integer | 0 | Control profile wait steps |
| PROF_WARMUPSTEP | integer | 0 | Control profile warmup steps |
| PROF_STEP | integer | 0 | Enable/disable profile, control profile active steps |
| PROF_PATH | string | /tmp/hpu_profile | Define profile folder |
| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile |
| PROF_RECORD_SHAPES | True/False | False | Control record_shapes option in the profiler |
To use these environment variables, add them to your docker run command with the -e flag. For example:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
-e PROF_WAITSTEP=10 \
-e PROF_WARMUPSTEP=10 \
-e PROF_STEP=1 \
-e PROF_PATH=/tmp/hpu_profile \
-e PROF_RANKS=0 \
-e PROF_RECORD_SHAPES=True \
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
--model-id $model
```
## Explanation: Understanding TGI on Gaudi
### The Warmup Process
To ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode, concatenate).
Note: Model warmup can take several minutes, especially for FP8 inference. For faster subsequent runs, refer to [Disk Caching Eviction Policy](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#disk-caching-eviction-policy).
### Understanding Parameter Tuning
#### Sequence Length Parameters
- `--max-input-tokens` is the maximum possible input prompt length. Default value is `4095`.
- `--max-total-tokens` is the maximum possible total length of the sequence (input and output). Default value is `4096`.
#### Batch Size Parameters
- For prefill operation, please set `--max-batch-prefill-tokens` as `bs * max-input-tokens`, where `bs` is your expected maximum prefill batch size.
- For decode operation, please set `--max-batch-size` as `bs`, where `bs` is your expected maximum decode batch size.
- Please note that the batch size will always be padded to the nearest multiple of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE` (see the sketch below).
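A minimal sketch of this rounding, assuming illustrative bucket sizes (the actual values come from `BATCH_BUCKET_SIZE` for decode and `PREFILL_BATCH_BUCKET_SIZE` for prefill):
```bash
# Sketch: a requested batch size is padded up to the next bucket boundary.
round_up() { echo $(( ( ($1 + $2 - 1) / $2 ) * $2 )); }

round_up 3 4    # prefill batch of 3 with PREFILL_BATCH_BUCKET_SIZE=4 -> 4
round_up 17 8   # decode batch of 17 with BATCH_BUCKET_SIZE=8         -> 24
```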
#### Performance and Memory Parameters
- `PAD_SEQUENCE_TO_MULTIPLE_OF` determines the sizes of the input-length buckets (see the sketch after this list). Since warmup creates several graphs for each bucket, it is important to adjust this value in proportion to the input sequence length; otherwise, out-of-memory issues can occur.
- `ENABLE_HPU_GRAPH` enables HPU graphs, which is crucial for performance. The recommended value is `true`.
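As a rough illustration of how prompt lengths map to these buckets (assuming `PAD_SEQUENCE_TO_MULTIPLE_OF=128`; the exact bucketing is handled by the server during warmup):
```bash
# Sketch: prompt lengths are padded up to the next multiple of PAD_SEQUENCE_TO_MULTIPLE_OF.
pad=128
for len in 100 129 500; do
  echo "$len -> $(( ( (len + pad - 1) / pad ) * pad ))"
done
# 100 -> 128, 129 -> 256, 500 -> 512
```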
## Reference
This section contains reference information about the Gaudi backend.
### Environment Variables
The following table lists the environment variables that can be used to configure the Gaudi backend; an example command combining several of them follows the table:
| Name | Value(s) | Default | Description | Usage |
|-----------------------------| :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |
| ENABLE_HPU_GRAPH | True/False | True | Enable hpu graph or not | add -e in docker run command |
| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths (e.g. 300/212) | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which requests are considered overdue and handled with priority. Otherwise, shorter requests are prioritized. | add -e in docker run command |
| USE_FLASH_ATTENTION | True/False | True | Whether to enable Habana Flash Attention, provided that the model supports it. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa | add -e in docker run command |
| FLASH_ATTENTION_RECOMPUTE | True/False | True | Whether to enable Habana Flash Attention in recompute mode on first token generation. | add -e in docker run command |
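For example, several of these variables can be combined in a single `docker run` command. The values below are illustrative only and should be tuned for your model and workload:
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
hf_token=YOUR_ACCESS_TOKEN
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
  -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
  -e BATCH_BUCKET_SIZE=16 \
  -e PREFILL_BATCH_BUCKET_SIZE=2 \
  -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
  -e LIMIT_HPU_GRAPH=True \
  ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
  --model-id $model
```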
## Contributing
Contributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md).
**Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder.
### Building the Docker Image from Source
To build the Docker image from source:
```bash
make -C backends/gaudi image
```
This builds the image and saves it as `tgi-gaudi`. You can then run TGI-Gaudi with this image:
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data
hf_token=YOUR_ACCESS_TOKEN
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
tgi-gaudi \
--model-id $model
```
For more details, see the [README of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/README.md) and the [Makefile of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/Makefile).

View File

@ -31,7 +31,7 @@ deployment instructions in the model card:
The service is launched simply by running the text-generation-inference container with two sets of parameters:
```
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.1.1-neuron <service_parameters>
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.2.1-neuron <service_parameters>
```
- system parameters are used to map ports, volumes and devices between the host and the service,

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HF_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 \
--model-id $model
```

View File

@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model --quantize bitsandbytes
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model --quantize bitsandbytes-nf4
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model --quantize gptq
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model --quantize gptq
```
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1-rocm \
ghcr.io/huggingface/text-generation-inference:3.2.1-rocm \
--model-id $model
```

View File

@ -1,3 +1,3 @@
# Using TGI with Intel Gaudi
Check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index).
You can use TGI on Intel Gaudi using the [TGI gaudi backend](https://huggingface.co/docs/text-generation-inference/backends/gaudi).

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1-intel-xpu \
ghcr.io/huggingface/text-generation-inference:3.2.1-intel-xpu \
--model-id $model --cuda-graphs 0
```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1-intel-cpu \
ghcr.io/huggingface/text-generation-inference:3.2.1-intel-cpu \
--model-id $model --cuda-graphs 0
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1 \
ghcr.io/huggingface/text-generation-inference:3.2.1 \
--model-id $model
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.1.1 \
ghcr.io/huggingface/text-generation-inference:3.2.1 \
--model-id $model
```
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:3.1.1 --help
docker run ghcr.io/huggingface/text-generation-inference:3.2.1 --help
```
</Tip>

View File

@ -163,7 +163,7 @@ hub = {
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.1.1"),
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.2.1"),
env=hub,
role=role,
)

View File

@ -14,6 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)

View File

@ -978,16 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1740049068,
"narHash": "sha256-heYzYOt+TSnRKHIV24s74yEjLkTbBfjNCWHdQEX++eI=",
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "143e8451efa22b120f97e6698508e9a0aed82769",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "hub-rotary",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/hub-rotary";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
@ -176,11 +176,15 @@
'';
};
dockerImage = pkgs.callPackage nix/docker.nix {
# Use plain nixpkgs without overlays for dockerTools. dockerTools
# uses a Python package for computing the layers from the transitive
# closure. However, this needs a lot of rebuilds due to our overlay.
dockerImage = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {
text-generation-inference = default;
};
dockerImageStreamed = pkgs.callPackage nix/docker.nix {
dockerImageStreamed = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {
text-generation-inference = default;
stream = true;
};

View File

@ -0,0 +1,109 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 16,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 506,
"logprob": -1.3984375,
"special": false,
"text": " the"
},
{
"id": 1331,
"logprob": -1.6953125,
"special": false,
"text": " people"
},
{
"id": 236764,
"logprob": -0.23535156,
"special": false,
"text": ","
},
{
"id": 532,
"logprob": -0.24316406,
"special": false,
"text": " and"
},
{
"id": 506,
"logprob": -0.12109375,
"special": false,
"text": " the"
},
{
"id": 2780,
"logprob": -1.1640625,
"special": false,
"text": " food"
},
{
"id": 236761,
"logprob": -0.21386719,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.64453125,
"special": false,
"text": "\n\n"
},
{
"id": 2094,
"logprob": -0.77734375,
"special": false,
"text": "This"
},
{
"id": 563,
"logprob": -0.040283203,
"special": false,
"text": " is"
},
{
"id": 496,
"logprob": -0.03125,
"special": false,
"text": " a"
},
{
"id": 6290,
"logprob": -0.03515625,
"special": false,
"text": " nice"
},
{
"id": 1977,
"logprob": -0.0020751953,
"special": false,
"text": " place"
},
{
"id": 236761,
"logprob": -0.0079956055,
"special": false,
"text": "."
},
{
"id": 107,
"logprob": -0.9921875,
"special": false,
"text": "\n"
},
{
"id": 106,
"logprob": -0.45507812,
"special": true,
"text": "<end_of_turn>"
}
],
"top_tokens": null
},
"generated_text": " the people, and the food.\n\nThis is a nice place.\n"
}

View File

@ -0,0 +1,613 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 100,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 1331,
"logprob": -0.34960938,
"special": false,
"text": " people"
},
{
"id": 8390,
"logprob": -0.14746094,
"special": false,
"text": " died"
},
{
"id": 528,
"logprob": -1.2265625,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -0.47070312,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.5859375,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": -0.0027770996,
"special": false,
"text": " States"
},
{
"id": 236761,
"logprob": -0.34765625,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.0859375,
"special": false,
"text": "\n\n"
},
{
"id": 818,
"logprob": -1.1640625,
"special": false,
"text": "The"
},
{
"id": 6816,
"logprob": -1.890625,
"special": false,
"text": " generally"
},
{
"id": 10951,
"logprob": -0.14648438,
"special": false,
"text": " accepted"
},
{
"id": 10967,
"logprob": -0.90625,
"special": false,
"text": " estimate"
},
{
"id": 563,
"logprob": -0.49414062,
"special": false,
"text": " is"
},
{
"id": 600,
"logprob": -0.65234375,
"special": false,
"text": " that"
},
{
"id": 236743,
"logprob": -1.2109375,
"special": false,
"text": " "
},
{
"id": 236825,
"logprob": -0.00088119507,
"special": false,
"text": "6"
},
{
"id": 236832,
"logprob": -6.580353e-05,
"special": false,
"text": "7"
},
{
"id": 236810,
"logprob": -5.2690506e-05,
"special": false,
"text": "5"
},
{
"id": 236764,
"logprob": -0.0001745224,
"special": false,
"text": ","
},
{
"id": 236771,
"logprob": -1.180172e-05,
"special": false,
"text": "0"
},
{
"id": 236771,
"logprob": -1.7881393e-06,
"special": false,
"text": "0"
},
{
"id": 236771,
"logprob": 0.0,
"special": false,
"text": "0"
},
{
"id": 1331,
"logprob": -0.44921875,
"special": false,
"text": " people"
},
{
"id": 8390,
"logprob": -0.011474609,
"special": false,
"text": " died"
},
{
"id": 528,
"logprob": -0.084472656,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -0.00034713745,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.028564453,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": -0.00012207031,
"special": false,
"text": " States"
},
{
"id": 236761,
"logprob": -1.15625,
"special": false,
"text": "."
},
{
"id": 3153,
"logprob": -0.103027344,
"special": false,
"text": " However"
},
{
"id": 236764,
"logprob": -0.009155273,
"special": false,
"text": ","
},
{
"id": 1070,
"logprob": -0.92578125,
"special": false,
"text": " some"
},
{
"id": 61806,
"logprob": -0.91796875,
"special": false,
"text": " historians"
},
{
"id": 4646,
"logprob": -1.3828125,
"special": false,
"text": " believe"
},
{
"id": 506,
"logprob": -0.65234375,
"special": false,
"text": " the"
},
{
"id": 5396,
"logprob": -0.8046875,
"special": false,
"text": " actual"
},
{
"id": 1548,
"logprob": -0.04321289,
"special": false,
"text": " number"
},
{
"id": 1451,
"logprob": -0.66015625,
"special": false,
"text": " could"
},
{
"id": 577,
"logprob": -0.091308594,
"special": false,
"text": " be"
},
{
"id": 618,
"logprob": -0.57421875,
"special": false,
"text": " as"
},
{
"id": 1494,
"logprob": -0.00036239624,
"special": false,
"text": " high"
},
{
"id": 618,
"logprob": -0.0001335144,
"special": false,
"text": " as"
},
{
"id": 236743,
"logprob": -0.0009689331,
"special": false,
"text": " "
},
{
"id": 236770,
"logprob": -0.26367188,
"special": false,
"text": "1"
},
{
"id": 236771,
"logprob": -0.17773438,
"special": false,
"text": "0"
},
{
"id": 3625,
"logprob": -0.012084961,
"special": false,
"text": " million"
},
{
"id": 236761,
"logprob": -0.21289062,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.37304688,
"special": false,
"text": "\n\n"
},
{
"id": 236777,
"logprob": -1.078125,
"special": false,
"text": "I"
},
{
"id": 1006,
"logprob": -1.3203125,
"special": false,
"text": " am"
},
{
"id": 3182,
"logprob": -1.078125,
"special": false,
"text": " looking"
},
{
"id": 573,
"logprob": -0.035888672,
"special": false,
"text": " for"
},
{
"id": 919,
"logprob": -1.25,
"special": false,
"text": " more"
},
{
"id": 1938,
"logprob": -1.2421875,
"special": false,
"text": " information"
},
{
"id": 580,
"logprob": -0.7734375,
"special": false,
"text": " on"
},
{
"id": 672,
"logprob": -0.73046875,
"special": false,
"text": " this"
},
{
"id": 59725,
"logprob": -0.75,
"special": false,
"text": " discrepancy"
},
{
"id": 532,
"logprob": -0.83984375,
"special": false,
"text": " and"
},
{
"id": 506,
"logprob": -0.7109375,
"special": false,
"text": " the"
},
{
"id": 5872,
"logprob": -1.2734375,
"special": false,
"text": " factors"
},
{
"id": 600,
"logprob": -0.22851562,
"special": false,
"text": " that"
},
{
"id": 19263,
"logprob": -1.1640625,
"special": false,
"text": " contributed"
},
{
"id": 531,
"logprob": -0.0010757446,
"special": false,
"text": " to"
},
{
"id": 506,
"logprob": -0.18945312,
"special": false,
"text": " the"
},
{
"id": 5777,
"logprob": -1.2734375,
"special": false,
"text": " wide"
},
{
"id": 2644,
"logprob": -0.01940918,
"special": false,
"text": " range"
},
{
"id": 529,
"logprob": -0.14550781,
"special": false,
"text": " of"
},
{
"id": 14287,
"logprob": -0.032470703,
"special": false,
"text": " estimates"
},
{
"id": 236761,
"logprob": -0.010375977,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.06591797,
"special": false,
"text": "\n\n"
},
{
"id": 8291,
"logprob": -0.8046875,
"special": false,
"text": "Here"
},
{
"id": 236789,
"logprob": -0.23828125,
"special": false,
"text": "'"
},
{
"id": 236751,
"logprob": -1.0728836e-06,
"special": false,
"text": "s"
},
{
"id": 496,
"logprob": -0.17480469,
"special": false,
"text": " a"
},
{
"id": 25890,
"logprob": -0.087402344,
"special": false,
"text": " breakdown"
},
{
"id": 529,
"logprob": -0.0021209717,
"special": false,
"text": " of"
},
{
"id": 506,
"logprob": -0.19140625,
"special": false,
"text": " the"
},
{
"id": 5872,
"logprob": -1.0078125,
"special": false,
"text": " factors"
},
{
"id": 20894,
"logprob": -0.26367188,
"special": false,
"text": " contributing"
},
{
"id": 531,
"logprob": -9.250641e-05,
"special": false,
"text": " to"
},
{
"id": 506,
"logprob": -0.008666992,
"special": false,
"text": " the"
},
{
"id": 5777,
"logprob": -0.6171875,
"special": false,
"text": " wide"
},
{
"id": 2644,
"logprob": -0.0023956299,
"special": false,
"text": " range"
},
{
"id": 529,
"logprob": -0.016723633,
"special": false,
"text": " of"
},
{
"id": 14287,
"logprob": -0.011352539,
"special": false,
"text": " estimates"
},
{
"id": 573,
"logprob": -0.30664062,
"special": false,
"text": " for"
},
{
"id": 506,
"logprob": -0.21386719,
"special": false,
"text": " the"
},
{
"id": 236743,
"logprob": -0.35351562,
"special": false,
"text": " "
},
{
"id": 236770,
"logprob": -3.5762787e-07,
"special": false,
"text": "1"
},
{
"id": 236819,
"logprob": 0.0,
"special": false,
"text": "9"
},
{
"id": 236770,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 236828,
"logprob": 0.0,
"special": false,
"text": "8"
},
{
"id": 7745,
"logprob": -0.70703125,
"special": false,
"text": " flu"
},
{
"id": 10248,
"logprob": -0.015258789,
"special": false,
"text": " pandemic"
},
{
"id": 4355,
"logprob": -0.83203125,
"special": false,
"text": " death"
},
{
"id": 25363,
"logprob": -7.43866e-05,
"special": false,
"text": " toll"
},
{
"id": 528,
"logprob": -0.08496094,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -6.67572e-06,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.0059509277,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": 0.0,
"special": false,
"text": " States"
}
],
"top_tokens": null
},
"generated_text": " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Okay, let's analyze the image.\n\nThe image is a solid, bright white color. There is nothing else visible within it. \n\nIt's essentially a blank white canvas or a completely white square. \n\nIs there anything specific you'd like me to do with this image, such as describe it further or imagine what it might represent?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741965894,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 74,
"prompt_tokens": 277,
"total_tokens": 351
}
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741965892,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 98,
"prompt_tokens": 277,
"total_tokens": 375
}
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741966313,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 67,
"prompt_tokens": 277,
"total_tokens": 344
}
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene a cow enjoying a day at the beach!",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741964480,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 74,
"prompt_tokens": 275,
"total_tokens": 349
}
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their reddish-brown color and distinctive white markings. \n\nIf you'd like, you can send me another image and Ill do my best to identify it!",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741964477,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"usage": {
"completion_tokens": 75,
"prompt_tokens": 279,
"total_tokens": 354
}
}

View File

@ -10,7 +10,7 @@
"tool_calls": [
{
"function": {
"arguments": "{\"format\":\"fahrenheit\",\"location\":\"Brooklyn, NY\"}",
"arguments": "{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}",
"description": null,
"name": "get_current_weather"
},
@ -21,7 +21,7 @@
}
}
],
"created": 1741263682,
"created": 1741372434,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",

View File

@ -10,7 +10,7 @@
"tool_calls": [
{
"function": {
"arguments": "{\"format\":\"fahrenheit\",\"location\":\"Brooklyn, NY\"}",
"arguments": "{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}",
"description": null,
"name": "get_current_weather"
},
@ -21,7 +21,7 @@
}
}
],
"created": 1741263684,
"created": 1741372657,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",

View File

@ -8,10 +8,10 @@
"tool_calls": [
{
"function": {
"arguments": "{\"",
"name": null
"arguments": "{",
"name": "get_current_weather"
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -22,187 +22,7 @@
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "function",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "\":",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": " {\"",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "_",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "name",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "\":",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -221,7 +41,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -232,157 +52,7 @@
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "get",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "_current",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "_weather",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "\",",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": " \"",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -401,7 +71,7 @@
"arguments": "location",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -412,7 +82,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -431,7 +101,7 @@
"arguments": "\":",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -442,7 +112,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -461,7 +131,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -472,7 +142,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -488,10 +158,10 @@
"tool_calls": [
{
"function": {
"arguments": "Paris",
"arguments": "Bro",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -502,7 +172,37 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "oklyn",
"name": null
},
"id": "0",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -521,7 +221,7 @@
"arguments": ",",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -532,7 +232,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -548,10 +248,10 @@
"tool_calls": [
{
"function": {
"arguments": " France",
"arguments": " NY",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -562,7 +262,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -581,7 +281,7 @@
"arguments": "\",",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -592,7 +292,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -611,7 +311,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -622,7 +322,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -641,7 +341,7 @@
"arguments": "format",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -652,7 +352,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -671,7 +371,7 @@
"arguments": "\":",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -682,7 +382,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -701,7 +401,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -712,7 +412,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -728,10 +428,10 @@
"tool_calls": [
{
"function": {
"arguments": "c",
"arguments": "f",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -742,7 +442,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -758,10 +458,10 @@
"tool_calls": [
{
"function": {
"arguments": "elsius",
"arguments": "ahrenheit",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -772,7 +472,7 @@
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -788,10 +488,10 @@
"tool_calls": [
{
"function": {
"arguments": "\"}}",
"arguments": "\"}",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -802,37 +502,7 @@
"logprobs": null
}
],
"created": 1741263685,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1741263685,
"created": 1741688515,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -5,20 +5,20 @@
"index": 0,
"logprobs": null,
"message": {
"content": "I am a helpful assistant!",
"content": "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI",
"role": "assistant",
"tool_calls": null
}
}
],
"created": 1741263686,
"created": 1741693957,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.1.2-dev0-native",
"usage": {
"completion_tokens": 23,
"prompt_tokens": 494,
"total_tokens": 517
"completion_tokens": 12,
"prompt_tokens": 53,
"total_tokens": 65
}
}

View File

@ -12,7 +12,7 @@
"logprobs": null
}
],
"created": 1741263687,
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -23,7 +23,7 @@
"choices": [
{
"delta": {
"content": " am",
"content": "'m",
"role": "assistant",
"tool_calls": null
},
@ -32,7 +32,127 @@
"logprobs": null
}
],
"created": 1741263687,
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " an",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " artificial",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " intelligence",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " model",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " known",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " as",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -52,7 +172,7 @@
"logprobs": null
}
],
"created": 1741263687,
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -63,7 +183,7 @@
"choices": [
{
"delta": {
"content": " helpful",
"content": " large",
"role": "assistant",
"tool_calls": null
},
@ -72,7 +192,7 @@
"logprobs": null
}
],
"created": 1741263687,
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -83,7 +203,7 @@
"choices": [
{
"delta": {
"content": " assistant",
"content": " language",
"role": "assistant",
"tool_calls": null
},
@ -92,7 +212,187 @@
"logprobs": null
}
],
"created": 1741263687,
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " model",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " (",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "LL",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "M",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": ")",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " or",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " convers",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "ational",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " AI",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1741694017,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -10,7 +10,7 @@
"tool_calls": [
{
"function": {
"arguments": "{\"format\":\"fahrenheit\",\"location\":\"Brooklyn, NY\"}",
"arguments": "{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}",
"description": null,
"name": "get_current_weather"
},
@ -21,7 +21,7 @@
}
}
],
"created": 1741263680,
"created": 1741372335,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",

View File

@ -10,10 +10,10 @@
"tool_calls": [
{
"function": {
"arguments": "{\"",
"name": null
"arguments": "{",
"name": "get_current_weather"
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -24,205 +24,7 @@
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "function",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "\":",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": " {\"",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "_",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "name",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "\":",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -244,7 +46,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -255,172 +57,7 @@
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "get",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "_current",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "_weather",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "\",",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": " \"",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -442,7 +79,7 @@
"arguments": "location",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -453,7 +90,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -475,7 +112,7 @@
"arguments": "\":",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -486,7 +123,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -508,7 +145,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -519,7 +156,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -541,7 +178,7 @@
"arguments": "Bro",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -552,7 +189,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -574,7 +211,7 @@
"arguments": "oklyn",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -585,7 +222,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -607,7 +244,7 @@
"arguments": ",",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -618,7 +255,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -640,7 +277,7 @@
"arguments": " NY",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -651,7 +288,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -673,7 +310,7 @@
"arguments": "\",",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -684,7 +321,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -706,7 +343,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -717,7 +354,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -739,7 +376,7 @@
"arguments": "format",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -750,7 +387,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -772,7 +409,7 @@
"arguments": "\":",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -783,7 +420,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -805,7 +442,7 @@
"arguments": " \"",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -816,7 +453,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -838,7 +475,7 @@
"arguments": "f",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -849,7 +486,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -871,7 +508,7 @@
"arguments": "ahrenheit",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -882,7 +519,7 @@
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
@ -901,10 +538,10 @@
"tool_calls": [
{
"function": {
"arguments": "\"}}",
"arguments": "\"}",
"name": null
},
"id": "",
"id": "0",
"index": 0,
"type": "function"
}
@ -915,40 +552,7 @@
"logprobs": null
}
],
"created": 1741263681,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "3.1.2-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": null,
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
]
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1741263681,
"created": 1741689423,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",

View File

@ -0,0 +1,170 @@
import base64
from io import BytesIO
from PIL import Image
import pytest
@pytest.fixture(scope="module")
def flash_gemma3_handle(launcher):
with launcher("google/gemma-3-4b-it", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gemma3(flash_gemma3_handle):
await flash_gemma3_handle.health(300)
return flash_gemma3_handle.client
async def test_flash_gemma3(flash_gemma3, response_snapshot):
response = await flash_gemma3.generate(
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
seed=42,
max_new_tokens=100,
)
assert (
response.generated_text
== " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
)
assert response.details.generated_tokens == 100
assert response == response_snapshot
async def test_flash_gemma3_image_cow_dog(flash_gemma3, response_snapshot):
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{
"type": "text",
"text": "What is the breed of the dog in the image?",
},
],
},
],
max_tokens=100,
)
assert (
response.choices[0].message.content
== "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their reddish-brown color and distinctive white markings. \n\nIf you'd like, you can send me another image and Ill do my best to identify it!"
)
assert response.usage["completion_tokens"] == 75
assert response == response_snapshot
async def test_flash_gemma3_image_cow(flash_gemma3, response_snapshot):
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "What is shown in this image?"},
],
},
],
max_tokens=100,
)
assert (
response.choices[0].message.content
== "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene a cow enjoying a day at the beach!"
)
assert response.usage["completion_tokens"] == 74
assert response == response_snapshot
async def test_exceed_window(flash_gemma3, response_snapshot):
response = await flash_gemma3.generate(
"This is a nice place. " * 800 + "I really enjoy the scenery,",
seed=42,
max_new_tokens=20,
)
assert (
response.generated_text
== " the people, and the food.\n\nThis is a nice place.\n"
)
assert response.details.generated_tokens == 16
assert response == response_snapshot
# Helper function to convert a Pillow image to a base64 data URL
def image_to_data_url(img: Image.Image, fmt: str) -> str:
buffer = BytesIO()
img.save(buffer, format=fmt)
img_data = buffer.getvalue()
b64_str = base64.b64encode(img_data).decode("utf-8")
mime_type = "image/png" if fmt.upper() == "PNG" else "image/jpeg"
return f"data:{mime_type};base64,{b64_str}"
async def test_flash_gemma3_image_base64_rgba(flash_gemma3, response_snapshot):
# Create an empty 100x100 PNG image with alpha (transparent background)
img = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
data_url = image_to_data_url(img, "PNG")
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{
"type": "text",
"text": "What do you see in this transparent image?",
},
],
},
],
max_tokens=100,
)
assert response == response_snapshot
async def test_flash_gemma3_image_base64_rgb_png(flash_gemma3, response_snapshot):
# Create an empty 100x100 PNG image without alpha (white background)
img = Image.new("RGB", (100, 100), (255, 255, 255))
data_url = image_to_data_url(img, "PNG")
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": "What do you see in this plain image?"},
],
},
],
max_tokens=100,
)
assert response == response_snapshot
async def test_flash_gemma3_image_base64_rgb_jpg(flash_gemma3, response_snapshot):
# Create an empty 100x100 JPEG image (white background)
img = Image.new("RGB", (100, 100), (255, 255, 255))
data_url = image_to_data_url(img, "JPEG")
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": "What do you see in this JPEG image?"},
],
},
],
max_tokens=100,
)
assert response == response_snapshot
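As a side note, the same data-URL conversion can be exercised against a running server through the OpenAI-compatible route. The sketch below is illustrative only and is not part of the test suite; the base URL, port, placeholder API key and model name are assumptions.
import base64
from io import BytesIO

from openai import OpenAI
from PIL import Image

def image_to_data_url(img: Image.Image, fmt: str) -> str:
    # Same conversion as the pytest helper above: in-memory encode, then base64 data URL.
    buffer = BytesIO()
    img.save(buffer, format=fmt)
    b64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    mime_type = "image/png" if fmt.upper() == "PNG" else "image/jpeg"
    return f"data:{mime_type};base64,{b64_str}"

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")  # assumed local TGI instance
data_url = image_to_data_url(Image.new("RGB", (100, 100), (255, 255, 255)), "PNG")
response = client.chat.completions.create(
    model="google/gemma-3-4b-it",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": data_url}},
                {"type": "text", "text": "What do you see in this image?"},
            ],
        }
    ],
    max_tokens=100,
)
print(response.choices[0].message.content)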

View File

@ -108,7 +108,7 @@ async def test_flash_llama_grammar_tools_nostream(
function=ChatCompletionOutputFunctionDefinition(
description=None,
name="get_current_weather",
arguments='{"format":"fahrenheit","location":"Brooklyn, NY"}',
arguments='{"location":"Brooklyn, NY","format":"fahrenheit"}',
),
)
]
@ -142,14 +142,15 @@ async def test_flash_llama_grammar_tools_openai(
chunks = []
tool = ""
name = ""
for chunk in stream:
if chunk.choices[0].delta.tool_calls[0].function.name:
name += chunk.choices[0].delta.tool_calls[0].function.name
tool += chunk.choices[0].delta.tool_calls[0].function.arguments
chunks.append(chunk)
assert (
tool
== '{"function": {"_name": "get_current_weather", "location": "Brooklyn, NY", "format": "fahrenheit"}}<|eot_id|>'
)
assert name == "get_current_weather"
assert tool == '{ "location": "Brooklyn, NY", "format": "fahrenheit"}'
assert chunks == response_snapshot
@ -184,7 +185,7 @@ async def test_flash_llama_grammar_tools_auto_nostream(
function=ChatCompletionOutputFunctionDefinition(
description=None,
name="get_current_weather",
arguments='{"format":"fahrenheit","location":"Brooklyn, NY"}',
arguments='{"location":"Brooklyn, NY","format":"fahrenheit"}',
),
)
]
@ -223,7 +224,7 @@ async def test_flash_llama_grammar_tools_choice_nostream(
function=ChatCompletionOutputFunctionDefinition(
description=None,
name="get_current_weather",
arguments='{"format":"fahrenheit","location":"Brooklyn, NY"}',
arguments='{"location":"Brooklyn, NY","format":"fahrenheit"}',
),
)
]
@ -250,23 +251,24 @@ async def test_flash_llama_grammar_tools_choice_stream(
},
{
"role": "user",
"content": "What is the weather like in Paris, France?",
"content": "What is the weather like in Brooklyn, New York?",
},
],
stream=True,
)
tool_calls_generated = ""
arguments = ""
chunks = []
name = ""
for chunk in stream:
tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments
if chunk.choices[0].delta.tool_calls[0].function.name:
name += chunk.choices[0].delta.tool_calls[0].function.name
arguments += chunk.choices[0].delta.tool_calls[0].function.arguments
assert chunk.choices[0].delta.content is None
chunks.append(chunk)
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
)
assert name == "get_current_weather"
assert arguments == '{ "location": "Brooklyn, NY", "format": "fahrenheit"}'
assert chunks == response_snapshot
@ -277,7 +279,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_nostream(
):
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
response = client.chat_completion(
max_tokens=100,
max_tokens=20,
seed=24,
tools=tools,
tool_choice="auto",
@ -297,9 +299,10 @@ async def test_flash_llama_grammar_tools_insufficient_information_nostream(
content_generated = response.choices[0].message.content
assert response.choices[0].message.tool_calls is None
######## FIXME before MERGE ############################
# TODO This is different from the streaming case, this is NOT normal.
assert content_generated == "I am a helpful assistant!"
assert (
content_generated
== "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI"
)
assert response == response_snapshot
@ -310,7 +313,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
):
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
stream = client.chat_completion(
max_tokens=100,
max_tokens=20,
seed=24,
tools=tools,
tool_choice="auto",
@ -334,7 +337,11 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
chunks.append(chunk)
assert chunk.choices[0].delta.tool_calls is None
assert content_generated == "I am a helpful assistant"
######## This is exactly the same as the non streaming case
assert (
content_generated
== "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI"
)
assert chunks == response_snapshot
@ -345,7 +352,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_auto(
):
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
stream = client.chat_completion(
max_tokens=100,
max_tokens=20,
seed=24,
tools=tools,
tool_choice="auto",
@ -371,7 +378,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_auto(
assert (
content_generated
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish,"
)
assert chunks == response_snapshot
@ -401,14 +408,18 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
)
tool_calls_generated = ""
name = ""
chunks = []
for chunk in stream:
assert chunk.choices[0].delta.content is None
if chunk.choices[0].delta.tool_calls[0].function.name:
name += chunk.choices[0].delta.tool_calls[0].function.name
tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments
assert name == "get_n_day_weather_forecast"
assert (
tool_calls_generated
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}}<|eot_id|>'
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
)
assert chunks == response_snapshot
@ -479,12 +490,17 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
)
chunks = []
tool_calls_generated = ""
name = ""
for chunk in stream:
assert chunk.choices[0].delta.content is None
if chunk.choices[0].delta.tool_calls[0].function.name:
name += chunk.choices[0].delta.tool_calls[0].function.name
tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments
chunks.append(chunk)
assert name == "get_n_day_weather_forecast"
assert (
tool_calls_generated
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days": 3}}<|eot_id|>'
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
)
assert chunks == response_snapshot
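For readers following the new streaming format exercised above: the function name now arrives once on the first tool-call delta (with id "0") and the arguments arrive as plain JSON fragments, so a client only needs to concatenate them. A minimal consumer sketch, assuming the openai client and a local TGI endpoint (the URL, model and tool schema below are illustrative assumptions, not the tests' fixtures):
import json

from openai import OpenAI

# Hypothetical tool schema, mirroring the get_current_weather tool used by the tests.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")  # assumed local server
stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is the weather like in Brooklyn, New York?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)

name, arguments = "", ""
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        call = delta.tool_calls[0].function
        if call.name:
            name += call.name            # emitted once, on the first delta
        if call.arguments:
            arguments += call.arguments  # plain JSON fragments, no wrapper object

print(name)                   # e.g. get_current_weather
print(json.loads(arguments))  # e.g. {'location': 'Brooklyn, NY', 'format': 'fahrenheit'}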

View File

@ -49,17 +49,11 @@ async def test_model_single_request(tgi_service):
max_new_tokens=128,
seed=42,
)
sample_expectations = {
"gpt2": "Deep Learning",
"llama": "Deep Learning",
"mistral": "Deep learning",
"qwen2": "Deep Learning",
"granite": "Deep learning",
}
assert sample_expectations[service_name] in response
# The response must be different
assert not response.startswith(greedy_expectations[service_name])
# Sampling with stop sequence
stop_sequence = sample_expectations[service_name][-5:]
# Sampling with stop sequence (using one of the words returned from the previous test)
stop_sequence = response.split(" ")[-5]
response = await tgi_service.client.text_generation(
"What is Deep Learning?",
do_sample=True,

View File

@ -15,6 +15,7 @@ dependencies = [
"numpy>=2.0",
"openai>=1.65",
"huggingface_hub>=0.29",
"pillow>=11.1.0",
]
[tool.isort]

View File

@ -1,8 +1,8 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
aiohappyeyeballs==2.4.6
# uv pip compile pyproject.toml
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.11.12
aiohttp==3.11.13
# via text-generation
aiosignal==1.3.2
# via aiohttp
@ -12,7 +12,7 @@ anyio==4.8.0
# via
# httpx
# openai
attrs==25.1.0
attrs==25.3.0
# via aiohttp
certifi==2025.1.31
# via
@ -25,13 +25,13 @@ distro==1.9.0
# via openai
docker==7.1.0
# via text-generation-integration-tests (pyproject.toml)
filelock==3.17.0
filelock==3.18.0
# via huggingface-hub
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2025.2.0
fsspec==2025.3.0
# via huggingface-hub
h11==0.14.0
# via httpcore
@ -39,7 +39,7 @@ httpcore==1.0.7
# via httpx
httpx==0.28.1
# via openai
huggingface-hub==0.29.0
huggingface-hub==0.29.3
# via
# text-generation-integration-tests (pyproject.toml)
# text-generation
@ -51,7 +51,7 @@ idna==3.10
# yarl
iniconfig==2.0.0
# via pytest
jiter==0.8.2
jiter==0.9.0
# via openai
multidict==6.1.0
# via
@ -59,15 +59,17 @@ multidict==6.1.0
# yarl
numpy==2.2.3
# via text-generation-integration-tests (pyproject.toml)
openai==1.65.3
openai==1.66.3
# via text-generation-integration-tests (pyproject.toml)
packaging==24.2
# via
# huggingface-hub
# pytest
pillow==11.1.0
# via text-generation-integration-tests (pyproject.toml)
pluggy==1.5.0
# via pytest
propcache==0.2.1
propcache==0.3.0
# via
# aiohttp
# yarl
@ -78,7 +80,7 @@ pydantic==2.10.6
# text-generation
pydantic-core==2.27.2
# via pydantic
pytest==8.3.4
pytest==8.3.5
# via
# text-generation-integration-tests (pyproject.toml)
# pytest-asyncio
@ -95,7 +97,7 @@ sniffio==1.3.1
# via
# anyio
# openai
syrupy==4.8.1
syrupy==4.9.0
# via text-generation-integration-tests (pyproject.toml)
text-generation==0.7.0
# via text-generation-integration-tests (pyproject.toml)

View File

@ -97,6 +97,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 },
]
[[package]]
name = "anyio"
version = "4.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/73/199a98fc2dae33535d6b8e8e6ec01f8c1d76c9adb096c6b7d64823038cde/anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a", size = 181126 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
]
[[package]]
name = "async-timeout"
version = "5.0.1"
@ -181,6 +196,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
]
[[package]]
name = "distro"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 },
]
[[package]]
name = "docker"
version = "7.1.0"
@ -276,6 +300,43 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/94/758680531a00d06e471ef649e4ec2ed6bf185356a7f9fbfbb7368a40bd49/fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b", size = 184484 },
]
[[package]]
name = "h11"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 },
]
[[package]]
name = "httpcore"
version = "1.0.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
{ name = "h11" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 },
]
[[package]]
name = "httpx"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "certifi" },
{ name = "httpcore" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
]
[[package]]
name = "huggingface-hub"
version = "0.29.0"
@ -312,6 +373,50 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
]
[[package]]
name = "jiter"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/1e/c2/e4562507f52f0af7036da125bb699602ead37a2332af0788f8e0a3417f36/jiter-0.9.0.tar.gz", hash = "sha256:aadba0964deb424daa24492abc3d229c60c4a31bfee205aedbf1acc7639d7893", size = 162604 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/82/39f7c9e67b3b0121f02a0b90d433626caa95a565c3d2449fea6bcfa3f5f5/jiter-0.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:816ec9b60fdfd1fec87da1d7ed46c66c44ffec37ab2ef7de5b147b2fce3fd5ad", size = 314540 },
{ url = "https://files.pythonhosted.org/packages/01/07/7bf6022c5a152fca767cf5c086bb41f7c28f70cf33ad259d023b53c0b858/jiter-0.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9b1d3086f8a3ee0194ecf2008cf81286a5c3e540d977fa038ff23576c023c0ea", size = 321065 },
{ url = "https://files.pythonhosted.org/packages/6c/b2/de3f3446ecba7c48f317568e111cc112613da36c7b29a6de45a1df365556/jiter-0.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1339f839b91ae30b37c409bf16ccd3dc453e8b8c3ed4bd1d6a567193651a4a51", size = 341664 },
{ url = "https://files.pythonhosted.org/packages/13/cf/6485a4012af5d407689c91296105fcdb080a3538e0658d2abf679619c72f/jiter-0.9.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ffba79584b3b670fefae66ceb3a28822365d25b7bf811e030609a3d5b876f538", size = 364635 },
{ url = "https://files.pythonhosted.org/packages/0d/f7/4a491c568f005553240b486f8e05c82547340572d5018ef79414b4449327/jiter-0.9.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cfc7d0a8e899089d11f065e289cb5b2daf3d82fbe028f49b20d7b809193958d", size = 406288 },
{ url = "https://files.pythonhosted.org/packages/d3/ca/f4263ecbce7f5e6bded8f52a9f1a66540b270c300b5c9f5353d163f9ac61/jiter-0.9.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e00a1a2bbfaaf237e13c3d1592356eab3e9015d7efd59359ac8b51eb56390a12", size = 397499 },
{ url = "https://files.pythonhosted.org/packages/ac/a2/522039e522a10bac2f2194f50e183a49a360d5f63ebf46f6d890ef8aa3f9/jiter-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1d9870561eb26b11448854dce0ff27a9a27cb616b632468cafc938de25e9e51", size = 352926 },
{ url = "https://files.pythonhosted.org/packages/b1/67/306a5c5abc82f2e32bd47333a1c9799499c1c3a415f8dde19dbf876f00cb/jiter-0.9.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9872aeff3f21e437651df378cb75aeb7043e5297261222b6441a620218b58708", size = 384506 },
{ url = "https://files.pythonhosted.org/packages/0f/89/c12fe7b65a4fb74f6c0d7b5119576f1f16c79fc2953641f31b288fad8a04/jiter-0.9.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1fd19112d1049bdd47f17bfbb44a2c0001061312dcf0e72765bfa8abd4aa30e5", size = 520621 },
{ url = "https://files.pythonhosted.org/packages/c4/2b/d57900c5c06e6273fbaa76a19efa74dbc6e70c7427ab421bf0095dfe5d4a/jiter-0.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6ef5da104664e526836070e4a23b5f68dec1cc673b60bf1edb1bfbe8a55d0678", size = 512613 },
{ url = "https://files.pythonhosted.org/packages/89/05/d8b90bfb21e58097d5a4e0224f2940568366f68488a079ae77d4b2653500/jiter-0.9.0-cp310-cp310-win32.whl", hash = "sha256:cb12e6d65ebbefe5518de819f3eda53b73187b7089040b2d17f5b39001ff31c4", size = 206613 },
{ url = "https://files.pythonhosted.org/packages/2c/1d/5767f23f88e4f885090d74bbd2755518050a63040c0f59aa059947035711/jiter-0.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:c43ca669493626d8672be3b645dbb406ef25af3f4b6384cfd306da7eb2e70322", size = 208371 },
{ url = "https://files.pythonhosted.org/packages/23/44/e241a043f114299254e44d7e777ead311da400517f179665e59611ab0ee4/jiter-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6c4d99c71508912a7e556d631768dcdef43648a93660670986916b297f1c54af", size = 314654 },
{ url = "https://files.pythonhosted.org/packages/fb/1b/a7e5e42db9fa262baaa9489d8d14ca93f8663e7f164ed5e9acc9f467fc00/jiter-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8f60fb8ce7df529812bf6c625635a19d27f30806885139e367af93f6e734ef58", size = 320909 },
{ url = "https://files.pythonhosted.org/packages/60/bf/8ebdfce77bc04b81abf2ea316e9c03b4a866a7d739cf355eae4d6fd9f6fe/jiter-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51c4e1a4f8ea84d98b7b98912aa4290ac3d1eabfde8e3c34541fae30e9d1f08b", size = 341733 },
{ url = "https://files.pythonhosted.org/packages/a8/4e/754ebce77cff9ab34d1d0fa0fe98f5d42590fd33622509a3ba6ec37ff466/jiter-0.9.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f4c677c424dc76684fea3e7285a7a2a7493424bea89ac441045e6a1fb1d7b3b", size = 365097 },
{ url = "https://files.pythonhosted.org/packages/32/2c/6019587e6f5844c612ae18ca892f4cd7b3d8bbf49461ed29e384a0f13d98/jiter-0.9.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2221176dfec87f3470b21e6abca056e6b04ce9bff72315cb0b243ca9e835a4b5", size = 406603 },
{ url = "https://files.pythonhosted.org/packages/da/e9/c9e6546c817ab75a1a7dab6dcc698e62e375e1017113e8e983fccbd56115/jiter-0.9.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c7adb66f899ffa25e3c92bfcb593391ee1947dbdd6a9a970e0d7e713237d572", size = 396625 },
{ url = "https://files.pythonhosted.org/packages/be/bd/976b458add04271ebb5a255e992bd008546ea04bb4dcadc042a16279b4b4/jiter-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c98d27330fdfb77913c1097a7aab07f38ff2259048949f499c9901700789ac15", size = 351832 },
{ url = "https://files.pythonhosted.org/packages/07/51/fe59e307aaebec9265dbad44d9d4381d030947e47b0f23531579b9a7c2df/jiter-0.9.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eda3f8cc74df66892b1d06b5d41a71670c22d95a1ca2cbab73654745ce9d0419", size = 384590 },
{ url = "https://files.pythonhosted.org/packages/db/55/5dcd2693794d8e6f4889389ff66ef3be557a77f8aeeca8973a97a7c00557/jiter-0.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:dd5ab5ddc11418dce28343123644a100f487eaccf1de27a459ab36d6cca31043", size = 520690 },
{ url = "https://files.pythonhosted.org/packages/54/d5/9f51dc90985e9eb251fbbb747ab2b13b26601f16c595a7b8baba964043bd/jiter-0.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:42f8a68a69f047b310319ef8e2f52fdb2e7976fb3313ef27df495cf77bcad965", size = 512649 },
{ url = "https://files.pythonhosted.org/packages/a6/e5/4e385945179bcf128fa10ad8dca9053d717cbe09e258110e39045c881fe5/jiter-0.9.0-cp311-cp311-win32.whl", hash = "sha256:a25519efb78a42254d59326ee417d6f5161b06f5da827d94cf521fed961b1ff2", size = 206920 },
{ url = "https://files.pythonhosted.org/packages/4c/47/5e0b94c603d8e54dd1faab439b40b832c277d3b90743e7835879ab663757/jiter-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:923b54afdd697dfd00d368b7ccad008cccfeb1efb4e621f32860c75e9f25edbd", size = 210119 },
{ url = "https://files.pythonhosted.org/packages/af/d7/c55086103d6f29b694ec79156242304adf521577530d9031317ce5338c59/jiter-0.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7b46249cfd6c48da28f89eb0be3f52d6fdb40ab88e2c66804f546674e539ec11", size = 309203 },
{ url = "https://files.pythonhosted.org/packages/b0/01/f775dfee50beb420adfd6baf58d1c4d437de41c9b666ddf127c065e5a488/jiter-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:609cf3c78852f1189894383cf0b0b977665f54cb38788e3e6b941fa6d982c00e", size = 319678 },
{ url = "https://files.pythonhosted.org/packages/ab/b8/09b73a793714726893e5d46d5c534a63709261af3d24444ad07885ce87cb/jiter-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d726a3890a54561e55a9c5faea1f7655eda7f105bd165067575ace6e65f80bb2", size = 341816 },
{ url = "https://files.pythonhosted.org/packages/35/6f/b8f89ec5398b2b0d344257138182cc090302854ed63ed9c9051e9c673441/jiter-0.9.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2e89dc075c1fef8fa9be219e249f14040270dbc507df4215c324a1839522ea75", size = 364152 },
{ url = "https://files.pythonhosted.org/packages/9b/ca/978cc3183113b8e4484cc7e210a9ad3c6614396e7abd5407ea8aa1458eef/jiter-0.9.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04e8ffa3c353b1bc4134f96f167a2082494351e42888dfcf06e944f2729cbe1d", size = 406991 },
{ url = "https://files.pythonhosted.org/packages/13/3a/72861883e11a36d6aa314b4922125f6ae90bdccc225cd96d24cc78a66385/jiter-0.9.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:203f28a72a05ae0e129b3ed1f75f56bc419d5f91dfacd057519a8bd137b00c42", size = 395824 },
{ url = "https://files.pythonhosted.org/packages/87/67/22728a86ef53589c3720225778f7c5fdb617080e3deaed58b04789418212/jiter-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fca1a02ad60ec30bb230f65bc01f611c8608b02d269f998bc29cca8619a919dc", size = 351318 },
{ url = "https://files.pythonhosted.org/packages/69/b9/f39728e2e2007276806d7a6609cda7fac44ffa28ca0d02c49a4f397cc0d9/jiter-0.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:237e5cee4d5d2659aaf91bbf8ec45052cc217d9446070699441a91b386ae27dc", size = 384591 },
{ url = "https://files.pythonhosted.org/packages/eb/8f/8a708bc7fd87b8a5d861f1c118a995eccbe6d672fe10c9753e67362d0dd0/jiter-0.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:528b6b71745e7326eed73c53d4aa57e2a522242320b6f7d65b9c5af83cf49b6e", size = 520746 },
{ url = "https://files.pythonhosted.org/packages/95/1e/65680c7488bd2365dbd2980adaf63c562d3d41d3faac192ebc7ef5b4ae25/jiter-0.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9f48e86b57bc711eb5acdfd12b6cb580a59cc9a993f6e7dcb6d8b50522dcd50d", size = 512754 },
{ url = "https://files.pythonhosted.org/packages/78/f3/fdc43547a9ee6e93c837685da704fb6da7dba311fc022e2766d5277dfde5/jiter-0.9.0-cp312-cp312-win32.whl", hash = "sha256:699edfde481e191d81f9cf6d2211debbfe4bd92f06410e7637dffb8dd5dfde06", size = 207075 },
{ url = "https://files.pythonhosted.org/packages/cd/9d/742b289016d155f49028fe1bfbeb935c9bf0ffeefdf77daf4a63a42bb72b/jiter-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:099500d07b43f61d8bd780466d429c45a7b25411b334c60ca875fa775f68ccb0", size = 207999 },
]
[[package]]
name = "multidict"
version = "6.1.0"
@ -411,6 +516,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/17/7f/d322a4125405920401450118dbdc52e0384026bd669939484670ce8b2ab9/numpy-2.2.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:783145835458e60fa97afac25d511d00a1eca94d4a8f3ace9fe2043003c678e4", size = 12839607 },
]
[[package]]
name = "openai"
version = "1.66.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "httpx" },
{ name = "jiter" },
{ name = "pydantic" },
{ name = "sniffio" },
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/77/5172104ca1df35ed2ed8fb26dbc787f721c39498fc51d666c4db07756a0c/openai-1.66.3.tar.gz", hash = "sha256:8dde3aebe2d081258d4159c4cb27bdc13b5bb3f7ea2201d9bd940b9a89faf0c9", size = 397244 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/78/5a/e20182f7b6171642d759c548daa0ba20a1d3ac10d2bd0a13fd75704a9ac3/openai-1.66.3-py3-none-any.whl", hash = "sha256:a427c920f727711877ab17c11b95f1230b27767ba7a01e5b66102945141ceca9", size = 567400 },
]
[[package]]
name = "packaging"
version = "24.2"
@ -420,6 +544,54 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 },
]
[[package]]
name = "pillow"
version = "11.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/af/c097e544e7bd278333db77933e535098c259609c4eb3b85381109602fb5b/pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20", size = 46742715 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/50/1c/2dcea34ac3d7bc96a1fd1bd0a6e06a57c67167fec2cff8d95d88229a8817/pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8", size = 3229983 },
{ url = "https://files.pythonhosted.org/packages/14/ca/6bec3df25e4c88432681de94a3531cc738bd85dea6c7aa6ab6f81ad8bd11/pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192", size = 3101831 },
{ url = "https://files.pythonhosted.org/packages/d4/2c/668e18e5521e46eb9667b09e501d8e07049eb5bfe39d56be0724a43117e6/pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2", size = 4314074 },
{ url = "https://files.pythonhosted.org/packages/02/80/79f99b714f0fc25f6a8499ecfd1f810df12aec170ea1e32a4f75746051ce/pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26", size = 4394933 },
{ url = "https://files.pythonhosted.org/packages/81/aa/8d4ad25dc11fd10a2001d5b8a80fdc0e564ac33b293bdfe04ed387e0fd95/pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07", size = 4353349 },
{ url = "https://files.pythonhosted.org/packages/84/7a/cd0c3eaf4a28cb2a74bdd19129f7726277a7f30c4f8424cd27a62987d864/pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482", size = 4476532 },
{ url = "https://files.pythonhosted.org/packages/8f/8b/a907fdd3ae8f01c7670dfb1499c53c28e217c338b47a813af8d815e7ce97/pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e", size = 4279789 },
{ url = "https://files.pythonhosted.org/packages/6f/9a/9f139d9e8cccd661c3efbf6898967a9a337eb2e9be2b454ba0a09533100d/pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269", size = 4413131 },
{ url = "https://files.pythonhosted.org/packages/a8/68/0d8d461f42a3f37432203c8e6df94da10ac8081b6d35af1c203bf3111088/pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49", size = 2291213 },
{ url = "https://files.pythonhosted.org/packages/14/81/d0dff759a74ba87715509af9f6cb21fa21d93b02b3316ed43bda83664db9/pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a", size = 2625725 },
{ url = "https://files.pythonhosted.org/packages/ce/1f/8d50c096a1d58ef0584ddc37e6f602828515219e9d2428e14ce50f5ecad1/pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65", size = 2375213 },
{ url = "https://files.pythonhosted.org/packages/dd/d6/2000bfd8d5414fb70cbbe52c8332f2283ff30ed66a9cde42716c8ecbe22c/pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457", size = 3229968 },
{ url = "https://files.pythonhosted.org/packages/d9/45/3fe487010dd9ce0a06adf9b8ff4f273cc0a44536e234b0fad3532a42c15b/pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35", size = 3101806 },
{ url = "https://files.pythonhosted.org/packages/e3/72/776b3629c47d9d5f1c160113158a7a7ad177688d3a1159cd3b62ded5a33a/pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2", size = 4322283 },
{ url = "https://files.pythonhosted.org/packages/e4/c2/e25199e7e4e71d64eeb869f5b72c7ddec70e0a87926398785ab944d92375/pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070", size = 4402945 },
{ url = "https://files.pythonhosted.org/packages/c1/ed/51d6136c9d5911f78632b1b86c45241c712c5a80ed7fa7f9120a5dff1eba/pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6", size = 4361228 },
{ url = "https://files.pythonhosted.org/packages/48/a4/fbfe9d5581d7b111b28f1d8c2762dee92e9821bb209af9fa83c940e507a0/pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1", size = 4484021 },
{ url = "https://files.pythonhosted.org/packages/39/db/0b3c1a5018117f3c1d4df671fb8e47d08937f27519e8614bbe86153b65a5/pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2", size = 4287449 },
{ url = "https://files.pythonhosted.org/packages/d9/58/bc128da7fea8c89fc85e09f773c4901e95b5936000e6f303222490c052f3/pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96", size = 4419972 },
{ url = "https://files.pythonhosted.org/packages/5f/bb/58f34379bde9fe197f51841c5bbe8830c28bbb6d3801f16a83b8f2ad37df/pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f", size = 2291201 },
{ url = "https://files.pythonhosted.org/packages/3a/c6/fce9255272bcf0c39e15abd2f8fd8429a954cf344469eaceb9d0d1366913/pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761", size = 2625686 },
{ url = "https://files.pythonhosted.org/packages/c8/52/8ba066d569d932365509054859f74f2a9abee273edcef5cd75e4bc3e831e/pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71", size = 2375194 },
{ url = "https://files.pythonhosted.org/packages/95/20/9ce6ed62c91c073fcaa23d216e68289e19d95fb8188b9fb7a63d36771db8/pillow-11.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a", size = 3226818 },
{ url = "https://files.pythonhosted.org/packages/b9/d8/f6004d98579a2596c098d1e30d10b248798cceff82d2b77aa914875bfea1/pillow-11.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b", size = 3101662 },
{ url = "https://files.pythonhosted.org/packages/08/d9/892e705f90051c7a2574d9f24579c9e100c828700d78a63239676f960b74/pillow-11.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3", size = 4329317 },
{ url = "https://files.pythonhosted.org/packages/8c/aa/7f29711f26680eab0bcd3ecdd6d23ed6bce180d82e3f6380fb7ae35fcf3b/pillow-11.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a", size = 4412999 },
{ url = "https://files.pythonhosted.org/packages/c8/c4/8f0fe3b9e0f7196f6d0bbb151f9fba323d72a41da068610c4c960b16632a/pillow-11.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1", size = 4368819 },
{ url = "https://files.pythonhosted.org/packages/38/0d/84200ed6a871ce386ddc82904bfadc0c6b28b0c0ec78176871a4679e40b3/pillow-11.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f", size = 4496081 },
{ url = "https://files.pythonhosted.org/packages/84/9c/9bcd66f714d7e25b64118e3952d52841a4babc6d97b6d28e2261c52045d4/pillow-11.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91", size = 4296513 },
{ url = "https://files.pythonhosted.org/packages/db/61/ada2a226e22da011b45f7104c95ebda1b63dcbb0c378ad0f7c2a710f8fd2/pillow-11.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c", size = 4431298 },
{ url = "https://files.pythonhosted.org/packages/e7/c4/fc6e86750523f367923522014b821c11ebc5ad402e659d8c9d09b3c9d70c/pillow-11.1.0-cp312-cp312-win32.whl", hash = "sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6", size = 2291630 },
{ url = "https://files.pythonhosted.org/packages/08/5c/2104299949b9d504baf3f4d35f73dbd14ef31bbd1ddc2c1b66a5b7dfda44/pillow-11.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf", size = 2626369 },
{ url = "https://files.pythonhosted.org/packages/37/f3/9b18362206b244167c958984b57c7f70a0289bfb59a530dd8af5f699b910/pillow-11.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5", size = 2375240 },
{ url = "https://files.pythonhosted.org/packages/fa/c5/389961578fb677b8b3244fcd934f720ed25a148b9a5cc81c91bdf59d8588/pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90", size = 3198345 },
{ url = "https://files.pythonhosted.org/packages/c4/fa/803c0e50ffee74d4b965229e816af55276eac1d5806712de86f9371858fd/pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb", size = 3072938 },
{ url = "https://files.pythonhosted.org/packages/dc/67/2a3a5f8012b5d8c63fe53958ba906c1b1d0482ebed5618057ef4d22f8076/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442", size = 3400049 },
{ url = "https://files.pythonhosted.org/packages/e5/a0/514f0d317446c98c478d1872497eb92e7cde67003fed74f696441e647446/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83", size = 3422431 },
{ url = "https://files.pythonhosted.org/packages/cd/00/20f40a935514037b7d3f87adfc87d2c538430ea625b63b3af8c3f5578e72/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f", size = 3446208 },
{ url = "https://files.pythonhosted.org/packages/28/3c/7de681727963043e093c72e6c3348411b0185eab3263100d4490234ba2f6/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73", size = 3509746 },
{ url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 },
]
[[package]]
name = "pluggy"
version = "1.5.0"
@ -656,6 +828,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
]
[[package]]
name = "sniffio"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
]
[[package]]
name = "syrupy"
version = "4.8.1"
@ -688,7 +869,10 @@ version = "2.0.1"
source = { virtual = "." }
dependencies = [
{ name = "docker" },
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "openai" },
{ name = "pillow" },
{ name = "pydantic" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
@ -699,7 +883,10 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "docker", specifier = ">=7" },
{ name = "huggingface-hub", specifier = ">=0.29" },
{ name = "numpy", specifier = ">=2.0" },
{ name = "openai", specifier = ">=1.65" },
{ name = "pillow", specifier = ">=11.1.0" },
{ name = "pydantic", specifier = ">2,<3" },
{ name = "pytest", specifier = ">=8.3.0" },
{ name = "pytest-asyncio", specifier = ">=0.23.1" },
@ -741,7 +928,7 @@ name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
wheels = [

View File

@ -97,11 +97,10 @@ fn get_config(
let filename = if !path.exists() {
// Assume it's a hub id
let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
let mut builder = ApiBuilder::from_env();
if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token))
} else {
ApiBuilder::new()
builder = builder.with_token(Some(token))
};
if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
builder = builder.with_user_agent("origin", origin.as_str());
@ -152,7 +151,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
"flashdecoding"
};
match config.head_dim {
match config.get_head_dim() {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
@ -214,6 +213,7 @@ struct RawConfig {
num_key_value_heads: Option<usize>,
num_hidden_layers: Option<usize>,
head_dim: Option<usize>,
text_config: Option<TextConfig>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
#[serde(rename = "num_experts_per_tok")]
@ -233,6 +233,11 @@ struct QuantizationConfig {
#[derive(Debug, Deserialize)]
struct VisionConfig {}
#[derive(Debug, Deserialize)]
struct TextConfig {
head_dim: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
@ -244,6 +249,7 @@ struct Config {
intermediate_size: Option<usize>,
hidden_size: Option<usize>,
model_type: Option<String>,
text_config: Option<TextConfig>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
num_experts_per_token: usize,
@ -253,6 +259,14 @@ struct Config {
}
impl Config {
fn get_head_dim(&self) -> Option<usize> {
self.head_dim.or_else(|| {
self.text_config
.as_ref()
.and_then(|text_config| text_config.head_dim)
})
}
fn flop(&self) -> Option<u64> {
if self.vision_config.is_some() {
// VLM are much harder to predict and VRAM requirements
@ -261,7 +275,7 @@ impl Config {
}
let num_heads = self.num_heads? as u64;
let num_kv_heads = self.num_kv_heads? as u64;
let head_dim = self.head_dim? as u64;
let head_dim = self.get_head_dim()? as u64;
let hidden_size = self.hidden_size? as u64;
let intermediate_size = (self.intermediate_size?
* (self.num_experts_per_token + self.num_shared_experts))
@ -289,7 +303,7 @@ impl Config {
}
// 2 for key and values
// 2 for f16 dtype?
Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
}
fn mlp_vram_per_tok(&self) -> Option<usize> {
@ -310,8 +324,8 @@ impl Config {
}
fn model_vram(&self) -> Option<usize> {
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
// gate + up + down = 3
let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
let layer_vram = mlp_vram + attn_vram + o_vram;
@ -349,6 +363,7 @@ impl From<RawConfig> for Config {
let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
let intermediate_size = other.intermediate_size;
let model_type = other.model_type;
let text_config = other.text_config;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
@ -360,6 +375,7 @@ impl From<RawConfig> for Config {
quantize,
head_dim,
model_type,
text_config,
vision_config,
is_encoder_decoder,
hidden_size,
@ -2067,6 +2083,7 @@ fn main() -> Result<(), LauncherError> {
let default_optimal = match config {
Some(ref config) => match config.model_type.as_deref() {
Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
Some("gemma3") => 8000,
_ => 4096,
},
None => 4096,
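To make the get_head_dim() fallback introduced above concrete: some multimodal configs (gemma3-style) nest head_dim under text_config instead of the top level, and the launcher now checks both. A rough Python equivalent of that lookup, with made-up values for illustration:
def get_head_dim(config: dict):
    # Top-level head_dim wins; otherwise fall back to text_config.head_dim.
    head_dim = config.get("head_dim")
    if head_dim is None:
        head_dim = (config.get("text_config") or {}).get("head_dim")
    return head_dim

assert get_head_dim({"head_dim": 128}) == 128                   # llama-style config
assert get_head_dim({"text_config": {"head_dim": 256}}) == 256  # gemma3-style config
assert get_head_dim({}) is None                                 # neither present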

View File

@ -1,4 +1,5 @@
{
stdenv,
dockerTools,
cacert,
text-generation-inference,
@ -11,13 +12,25 @@ in
build {
name = "tgi-docker";
tag = "latest";
compressor = "zstd";
config = {
EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
Env = [
"HF_HOME=/data"
"PORT=80"
# The CUDA container toolkit will mount the driver shim into the
# container. We just have to ensure that the dynamic loader finds
# the libraries.
"LD_LIBRARY_PATH=/usr/lib64"
];
};
contents = [ cacert ];
extraCommands = ''
mkdir -p tmp
chmod -R 1777 tmp
'';
contents = [
cacert
stdenv.cc
];
}

View File

@ -16,8 +16,8 @@
grpcio-reflection,
grpcio-status,
grpcio-tools,
hf-kernels,
hf-transfer,
kernels,
loguru,
mamba-ssm,
moe,
@ -91,8 +91,8 @@ buildPythonPackage {
grpcio-reflection
grpcio-status
grpcio-tools
hf-kernels
hf-transfer
kernels
loguru
mamba-ssm
moe

700
router/src/chat.rs Normal file
View File

@ -0,0 +1,700 @@
use crate::{
infer::InferError, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatCompletionLogprobs, CompletionType, DeltaToolCall, Function, FunctionDefinition,
StreamOptions, StreamResponse, TextMessage, ToolCallDelta, Usage,
};
use serde::Deserialize;
use serde_json::Value;
#[derive(Debug, Deserialize)]
struct ToolCall {
_name: String,
#[serde(flatten, default)]
/// Using Map to preserve order
arguments: serde_json::Map<String, Value>,
}
#[derive(Debug, Deserialize)]
struct Call {
function: ToolCall,
}
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatEvent {
NoTool,
Events(Vec<CompletionType>),
}
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ChatChoice {
NoTool,
ToolCalls(Vec<crate::ToolCall>),
}
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}",
e, generated_text
))
})?;
let name = call.function._name;
match &name[..] {
"no_tool" => {
// parse the content message
Ok(ChatChoice::NoTool)
}
name => {
let tool_calls = vec![crate::ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name: name.to_string(),
arguments: serde_json::to_value(call.function.arguments).map_err(|err| {
InferError::ToolError(format!(
"Could not convert arguments to JSON map {err}"
))
})?,
},
}];
Ok(ChatChoice::ToolCalls(tool_calls))
}
}
}
/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
stream_token: &StreamResponse,
logprobs: bool,
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
function_name: Option<String>,
id: String,
) -> CompletionType {
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
});
// replace the content with the tool calls if grammar is present
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
} else {
None
};
let (content, tool_calls) = if inner_using_tools {
// Cast into a vec
(None, content)
} else {
(content, None)
};
let finish_reason = stream_token
.details
.as_ref()
.map(|details| details.finish_reason.format(true));
let delta = match (content, tool_calls) {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: delta,
..Default::default()
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id,
r#type: "function".to_string(),
function: Function {
name: function_name,
arguments: tool_calls,
},
}],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "".to_string(),
..Default::default()
}),
};
let choices = vec![ChatCompletionChoice {
index: 0,
delta,
logprobs,
finish_reason,
}];
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id,
system_fingerprint,
current_time,
choices,
None,
))
}
#[derive(Debug)]
enum StreamState {
/// Before the tool call has been parsed
Buffering,
/// We detected a tool call here
Tool,
/// This is without tool calling
Content,
}
pub struct ChatState {
state: StreamState,
text: String,
options: StreamOptions,
model_id: String,
fingerprint: String,
logprobs: bool,
id: String,
}
impl ChatState {
pub fn new(
using_tools: bool,
options: StreamOptions,
fingerprint: String,
model_id: String,
logprobs: bool,
id: String,
) -> Self {
let state = if using_tools {
StreamState::Buffering
} else {
StreamState::Content
};
let text = String::new();
Self {
state,
text,
options,
fingerprint,
model_id,
logprobs,
id,
}
}
pub fn push(&mut self, mut stream_token: StreamResponse) -> ChatEvent {
let mut events = vec![];
let token_text = &stream_token.token.text;
match self.state {
StreamState::Buffering => {
self.text.push_str(token_text);
tracing::info!("Current text {:?}", self.text);
let partial = &self.text;
let partial =
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
// The model may select `no_tool` before any content has been emitted
if call.function._name != "no_tool" {
stream_token.token.text = "{".to_string();
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
true,
self.fingerprint.clone(),
self.model_id.clone(),
Some(call.function._name),
self.id.clone(),
);
events.push(chat_complete);
self.state = StreamState::Tool;
} else {
return ChatEvent::NoTool;
}
}
}
StreamState::Tool => {
self.text.push_str(token_text);
if serde_json::from_str::<Call>(&self.text).is_ok() {
self.state = StreamState::Buffering;
let mut text = stream_token.token.text.trim_end();
// Effectively trimming only the last closing brace
if text.ends_with('}') {
text = &text[..text.len() - 1];
}
stream_token.token.text = text.to_string();
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
true,
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
} else {
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
true,
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
}
}
StreamState::Content => {
let chat_complete = create_event_from_stream_token(
&stream_token,
self.logprobs,
false,
self.fingerprint.clone(),
self.model_id.clone(),
None,
self.id.clone(),
);
events.push(chat_complete);
}
}
if self.options.include_usage {
if let Some(details) = stream_token.details {
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
let usage = Usage {
completion_tokens,
prompt_tokens,
total_tokens,
};
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
id: String::new(),
created: current_time,
model: self.model_id.clone(),
system_fingerprint: self.fingerprint.clone(),
choices: vec![],
usage: Some(Usage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}),
});
events.push(chat_complete);
}
}
ChatEvent::Events(events)
}
}
#[cfg(test)]
mod tests {
use crate::{
ChatCompletionChoice, ChatCompletionDelta, FinishReason, StreamDetails, TextMessage, Token,
};
use super::*;
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
match event {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(choices.len(), 1);
if let ChatCompletionChoice {
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
..
} = &choices[0]
{
assert_eq!(tool_calls.len(), 1);
let DeltaToolCall {
index,
id,
r#type,
function,
} = &tool_calls[0];
assert_eq!(*index, 0);
assert_eq!(id, "0");
assert_eq!(r#type, "function");
(function.name.as_ref(), &function.arguments)
} else {
panic!("Expected plain message");
}
}
_ => panic!("Unexpected chunk"),
}
}
#[test]
fn test_chat_stream() {
let mut chat_state = ChatState::new(
false,
StreamOptions {
include_usage: false,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: "Hi".to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
});
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 1);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(
choices,
&[ChatCompletionChoice {
index: 0,
delta: ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "Hi".to_string(),
tool_call_id: None,
}),
logprobs: None,
finish_reason: None,
}]
);
}
_ => panic!("Unexpected chunk"),
}
} else {
panic!("Expected chat events");
}
}
#[test]
fn test_chat_stream_usage() {
let mut chat_state = ChatState::new(
false,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let events = chat_state.push(StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: "Hi".to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: Some(StreamDetails {
input_length: 2,
generated_tokens: 10,
seed: None,
finish_reason: FinishReason::Length,
}),
});
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 2);
match &events[0] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
assert_eq!(
choices,
&[ChatCompletionChoice {
index: 0,
delta: ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "Hi".to_string(),
tool_call_id: None,
}),
logprobs: None,
// HAS A FINISH REASON
finish_reason: Some("length".to_string()),
}]
);
}
_ => panic!("Unexpected chunk"),
}
match &events[1] {
CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => {
assert_eq!(
*usage,
Some(Usage {
prompt_tokens: 2,
completion_tokens: 10,
total_tokens: 12,
})
);
}
_ => panic!("Unexpected chunk"),
}
} else {
panic!("Expected chat events");
}
}
#[test]
fn test_chat_stream_tool_no_tool_simple() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":".to_string(),
" \"".to_string(), // Token 14
"I".to_string(), // Event 1
" am".to_string(), // Event 2
" a".to_string(), // Event 3
" helpful".to_string(), // Event 4
" assistant".to_string(), // Event 5
"!\"".to_string(), // Event 6 (with trailing quore removed)
"}".to_string(),
"}".to_string(),
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..10] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
} else {
panic!("Expected chat events");
}
}
// No tool output
let events = chat_state.push(tokens[10].clone());
assert!(
matches!(events, ChatEvent::NoTool),
"Expected a NoTool event"
);
}
#[test]
fn test_chat_stream_tool_no_tool_empty() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"no".to_string(),
"_tool".to_string(),
"\",".to_string(),
" \"".to_string(),
"content".to_string(),
"\":\"".to_string(),
"\"}".to_string(), // Token 13
"}".to_string(), // Event 1
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..10] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
} else {
panic!("Expected chat events");
}
}
// No tool output
let events = chat_state.push(tokens[10].clone());
assert!(
matches!(events, ChatEvent::NoTool),
"Expected a NoTool event"
);
}
#[test]
fn test_chat_stream_tool_get_weather() {
let mut chat_state = ChatState::new(
true,
StreamOptions {
include_usage: true,
},
"fingerprint".to_string(),
"model_id".to_string(),
false,
"0".to_string(),
);
let tokens = vec![
"{\"".to_string(),
"function".to_string(),
"\":".to_string(),
" {\"".to_string(),
"_".to_string(),
"name".to_string(),
"\":".to_string(),
" \"".to_string(),
"get".to_string(),
"_current".to_string(),
"_weather".to_string(),
"\",".to_string(),
// Event 1 is the function name
// Event 2 is the start of the arguments "{"
" \"".to_string(), // Event 3
"location".to_string(), // Event 4
"\":".to_string(), // Event 5
" \"".to_string(), // Event 6
"San".to_string(), // Event 7
" Francisco".to_string(), // Event 8
",".to_string(), // Event 9
" CA".to_string(), // Event 10
"\",".to_string(), // Event 11
" \"".to_string(), // Event 12
"format".to_string(), // Event 13
"\":".to_string(), // Event 14
" \"".to_string(), // Event 15
"c".to_string(), // Event 16
"elsius".to_string(), // Event 17
"\"}}".to_string(), // Event 18 retained (trailing brace removed)
];
let tokens: Vec<_> = tokens
.into_iter()
.map(|text| StreamResponse {
generated_text: None,
token: Token {
id: 42,
text: text.to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
index: 0,
details: None,
})
.collect();
// Initial ignored output
for token in &tokens[..11] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
} else {
panic!("Expected chat events");
}
}
// Stream the tool call arguments
let mut output = String::new();
let mut output_name = String::new();
for token in &tokens[11..11 + 17] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 1);
let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name {
assert_eq!(name, "get_current_weather");
output_name.push_str(&name);
}
output.push_str(arguments);
} else {
panic!("Expected chat events");
}
}
assert_eq!(output_name, "get_current_weather");
assert_eq!(
output,
"{ \"location\": \"San Francisco, CA\", \"format\": \"celsius\"}"
);
// Remaining tokens (the closing braces) produce no further events
for token in &tokens[11 + 17..] {
let events = chat_state.push(token.clone());
if let ChatEvent::Events(events) = events {
assert_eq!(events.len(), 0, "{events:?}");
} else {
panic!("Expected chat events");
}
}
}
}

View File

@ -216,6 +216,19 @@ impl Qwen2_5Vl {
}
}
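// The structs below deserialize the vision-related part of a Gemma3
// `config.json`, e.g. (sketch, values illustrative):
//   { "model_type": "gemma3", "vision_config": { "image_size": 896, "patch_size": 14 } }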
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Gemma3VisionConfig {
pub(crate) image_size: usize,
pub(crate) patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Gemma3 {
vision_config: Gemma3VisionConfig,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
@ -249,6 +262,8 @@ pub enum Config {
Paligemma(Paligemma),
Gemma,
Gemma2,
Gemma3(Gemma3),
Gemma3Text,
Cohere,
Drbx,
Falcon,

View File

@ -16,7 +16,7 @@ pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Erro
Ok(Local::now().format(&format_str).to_string())
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
@ -33,7 +33,16 @@ impl ChatTemplate {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
// TODO: replace with better solution
// Hack: adjust the gemma3 template so the system prompt is read as a plain
// string instead of a list of content chunks, i.e.
// replace 'messages[0]['content'][0]['text']' with 'messages[0]['content']'
let mutated_template = template.replace(
"messages[0]['content'][0]['text']",
"messages[0]['content']",
);
let template_str = mutated_template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
tracing::debug!("Loading template: {}", template_str);
@ -123,8 +132,8 @@ mod tests {
use crate::infer::chat_template::{raise_exception, strftime_now};
use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
TokenizerConfigToken, Tool,
ChatTemplateInputs, Message, MessageBody, MessageChunk, MessageContent, TextMessage,
TokenizerConfigToken, Tool, Url,
};
use chrono::Local;
use minijinja::Environment;
@ -1230,4 +1239,98 @@ TOOL CALL ID: 0
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
#[test]
fn test_chat_template_with_special_system_prompt() {
// chat template from gemma3
let ct = ChatTemplate::new(
r#"{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
' -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>
' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model
'}}
{%- endif -%}
"#
.to_string(),
Some(TokenizerConfigToken::String("<bos>".to_string())),
Some(TokenizerConfigToken::String("</eos>".to_string())),
);
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "system".to_string(),
body: MessageBody::Content {
content: MessageContent::MultipleChunks(vec![MessageChunk::Text {
text: "You are a helpful assistant.".to_string(),
}]),
},
},
Message {
name: None,
role: "user".to_string(),
body: MessageBody::Content {
content: MessageContent::MultipleChunks(vec![
MessageChunk::Text {
text: "I'm already using this supplement ".to_string(),
},
MessageChunk::ImageUrl {
image_url: Url {
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG".to_string()
},
},
MessageChunk::Text {
text: "and I want to use this one too ".to_string()
},
MessageChunk::ImageUrl {
image_url: Url {
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg".to_string()
},
},
MessageChunk::Text {
text: " what are cautions?".to_string()
},
]),
},
},
];
let result = ct.apply(msgs, None);
let expected = "<bos><start_of_turn>user\nYou are a helpful assistant.\n\nI'm already using this supplement ![](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG)and I want to use this one too ![](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg) what are cautions?<end_of_turn>\n<start_of_turn>model\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}

View File

@ -52,7 +52,7 @@ pub struct Infer {
/// Request backend
backend: Arc<dyn Backend + Send + Sync>,
/// Chat template
chat_template: Option<ChatTemplate>,
pub(crate) chat_template: Option<ChatTemplate>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Backend health

View File

@ -40,13 +40,13 @@ impl ToolGrammar {
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
// "properties": {
// "content": {
// "type": "string",
// "description": "The response content",
// }
// },
// "required": ["content"]
}),
},
}))

View File

@ -8,6 +8,7 @@ pub mod validation;
mod kserve;
pub mod logging;
mod chat;
mod sagemaker;
pub mod usage_stats;
mod vertex;
@ -20,6 +21,7 @@ use serde::{Deserialize, Serialize};
use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
use uuid::Uuid;
use validation::Validation;
#[allow(clippy::large_enum_variant)]
@ -150,6 +152,11 @@ impl HubTokenizerConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatTemplateStandalone {
pub chat_template: ChatTemplateVersions,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum TokenizerConfigToken {
@ -171,6 +178,7 @@ impl TokenizerConfigToken {
pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor),
Idefics3Processor(Idefics2Preprocessor),
Gemma3Processor(Gemma3Processor),
}
impl HubPreprocessorConfig {
@ -186,6 +194,12 @@ pub struct Idefics2Preprocessor {
do_image_splitting: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Gemma3Processor {
#[serde(default)]
do_image_splitting: bool,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,
@ -541,6 +555,7 @@ pub(crate) struct Chunk {
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug))]
pub(crate) struct ChatCompletion {
pub id: String,
#[schema(example = "1706270835")]
@ -553,6 +568,7 @@ pub(crate) struct ChatCompletion {
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug))]
pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: OutputMessage,
@ -561,6 +577,7 @@ pub(crate) struct ChatCompletionComplete {
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct ChatCompletionLogprobs {
content: Vec<ChatCompletionLogprob>,
}
@ -619,6 +636,7 @@ impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct ChatCompletionLogprob {
token: String,
logprob: f32,
@ -626,12 +644,14 @@ pub(crate) struct ChatCompletionLogprob {
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct ChatCompletionTopLogprob {
token: String,
logprob: f32,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
@ -640,6 +660,7 @@ pub(crate) struct Usage {
#[derive(Clone, Serialize, ToSchema)]
#[serde(tag = "object")]
#[cfg_attr(test, derive(Debug))]
enum CompletionType {
#[serde(rename = "chat.completion.chunk")]
ChatCompletionChunk(ChatCompletionChunk),
@ -707,6 +728,7 @@ impl ChatCompletion {
}
}
#[derive(Clone, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug))]
pub(crate) struct ChatCompletionChunk {
pub id: String,
#[schema(example = "1706270978")]
@ -719,6 +741,7 @@ pub(crate) struct ChatCompletionChunk {
}
#[derive(Clone, Serialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct ChatCompletionChoice {
pub index: u32,
pub delta: ChatCompletionDelta,
@ -735,6 +758,7 @@ pub struct ToolCallDelta {
#[derive(Clone, Debug, Serialize, ToSchema)]
#[serde(untagged)]
#[cfg_attr(test, derive(PartialEq))]
enum ChatCompletionDelta {
Chat(TextMessage),
Tool(ToolCallDelta),
@ -759,48 +783,17 @@ impl ChatCompletionChunk {
pub(crate) fn new(
model: String,
system_fingerprint: String,
delta: Option<String>,
tool_calls: Option<Vec<String>>,
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
choices: Vec<ChatCompletionChoice>,
usage: Option<Usage>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: delta,
..Default::default()
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
}],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "".to_string(),
..Default::default()
}),
};
Self {
id: String::new(),
created,
model,
system_fingerprint,
choices: vec![ChatCompletionChoice {
index: 0,
delta,
logprobs,
finish_reason,
}],
usage: None,
choices,
usage,
}
}
}
@ -915,7 +908,7 @@ pub(crate) struct ChatRequest {
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
pub stream_options: StreamOptions,
}
impl ChatRequest {
@ -1015,13 +1008,37 @@ impl ChatRequest {
using_tools,
))
}
fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>> {
let mut id: usize = 0;
for message in &self.messages {
if let MessageBody::Tool { tool_calls } = &message.body {
for tool_call in tool_calls {
let new_id: usize = tool_call.id.parse()?;
id = std::cmp::max(id, new_id + 1);
}
}
}
Ok(id.to_string())
}
/// Try to produce a linearly increasing id,
/// falling back to a fresh Uuid when existing
/// tool call ids cannot be parsed as integers
fn next_tool_call_id(&self) -> String {
self.next_int_id()
.unwrap_or_else(|_| Uuid::new_v4().to_string())
}
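// Example (sketch): with existing tool call ids "0" and "1" in the messages,
// `next_tool_call_id` returns "2"; a non-integer id such as "call_abc" makes
// `next_int_id` fail, so a fresh UUID string is used instead.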
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
struct StreamOptions {
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
#[schema(example = "true")]
#[serde(default)]
include_usage: bool,
}
@ -1445,7 +1462,7 @@ pub(crate) struct ChatTokenizeResponse {
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Clone)]
pub(crate) struct StreamDetails {
#[schema(example = "length")]
pub finish_reason: FinishReason,
@ -1457,7 +1474,7 @@ pub(crate) struct StreamDetails {
pub input_length: u32,
}
#[derive(Serialize, ToSchema)]
#[derive(Serialize, ToSchema, Clone)]
pub(crate) struct StreamResponse {
pub index: u32,
pub token: Token,
@ -1700,9 +1717,25 @@ mod tests {
assert!(matches!(
request.stream_options,
Some(StreamOptions {
StreamOptions {
include_usage: true
})
}
));
let json = json!({
"model": "",
"messages": [{
"role": "user",
"content": "Hello"
}]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert!(matches!(
request.stream_options,
StreamOptions {
include_usage: false
}
));
}

View File

@ -1,3 +1,4 @@
use crate::chat::{ChatChoice, ChatEvent, ChatState};
/// HTTP Server logic
use crate::config::Config;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
@ -47,8 +48,6 @@ use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use regex::Regex;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
use std::io::BufReader;
@ -1114,62 +1113,6 @@ pub(crate) async fn completions(
}
}
enum StreamState {
Buffering,
BufferTrailing,
Content { skip_close_quote: bool },
}
/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
stream_token: &StreamResponse,
logprobs: bool,
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
) -> Event {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
} else {
None
};
(content, None)
};
let finish_reason = stream_token
.details
.as_ref()
.map(|details| details.finish_reason.format(true));
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
));
event.json_data(chat_complete).unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
}
/// Generate tokens
#[utoipa::path(
post,
@ -1208,7 +1151,7 @@ pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Json(chat): Json<ChatRequest>,
Json(mut chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
@ -1219,8 +1162,11 @@ pub(crate) async fn chat_completions(
logprobs,
..
} = chat.clone();
tracing::debug!("Got chat_template {:?}", infer.chat_template);
let id = chat.next_tool_call_id();
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?;
chat.clone().try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters));
let logprobs = logprobs.unwrap_or_default();
@ -1232,167 +1178,41 @@ pub(crate) async fn chat_completions(
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream
if stream {
let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Failed to compile regex: {}", e),
error_type: "regex".to_string(),
}),
))
}
};
let (headers, response_stream) = generate_stream_internal(
infer.clone(),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
} else {
StreamState::Content {
skip_close_quote: false,
}
};
let mut response_as_tool = using_tools;
let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
while let Some(result) = response_stream.next().await {
match result{
Ok(stream_token) => {
let token_text = &stream_token.token.text.clone();
let usage = stream_token.details.as_ref().map(|details| {
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Usage {
completion_tokens,
prompt_tokens,
total_tokens,
let events = state.push(stream_token);
match events{
ChatEvent::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer).unwrap();
assert!(!using_tools);
let (_headers, response_stream2) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
response_stream = Box::pin(response_stream2);
}
});
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => {
let infix_text = "\"content\":\"";
json_buffer.push_str(&token_text.replace(" ", ""));
// keep capturing until we find the infix text
match json_buffer.find(infix_text) {
Some(content_key_index) => {
json_buffer =
json_buffer[content_key_index + infix_text.len()..].to_string();
}
None => {
continue;
}
}
// if there is leftover text after removing the infix text, we need to send it
if !json_buffer.is_empty() {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete =
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
Some(json_buffer.clone()),
None,
current_time,
None,
None,
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
ChatEvent::Events(events) => {
for chat_complete in events{
yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| {
tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
let should_send_usage = usage.is_some()
&& stream_options
.as_ref()
.is_some_and(|opts| opts.include_usage);
if should_send_usage {
let usage_data = usage.unwrap();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![],
usage: Some(Usage {
prompt_tokens: usage_data.prompt_tokens,
completion_tokens: usage_data.completion_tokens,
total_tokens: usage_data.total_tokens,
}),
});
yield Ok(Event::default()
.json_data(chat_complete)
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
}
}
Err(err) => yield Ok(err.into_openai_event())
@ -1404,8 +1224,13 @@ pub(crate) async fn chat_completions(
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
let (headers, input_length, Json(generation)) =
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
let (mut headers, mut input_length, Json(generation)) = generate_internal(
Extension(infer.clone()),
compute_type.clone(),
Json(generate_request),
span.clone(),
)
.await?;
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@ -1413,55 +1238,25 @@ pub(crate) async fn chat_completions(
.as_secs();
let (tool_calls, output) = if using_tools {
let gen_text_value: Value =
serde_json::from_str(&generation.generated_text).map_err(|e| {
InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}",
e, generation.generated_text
))
})?;
let function = gen_text_value.get("function").ok_or(InferError::ToolError(
"No function found in generated text".to_string(),
))?;
let name = function
.get("_name")
.and_then(Value::as_str)
.ok_or(InferError::ToolError(
"No _name found in generated text".to_string(),
))?
.to_string();
let mut arguments = function.clone();
if let Value::Object(ref mut props) = arguments {
props.remove("_name");
}
match name.as_str() {
"no_tool" => {
// parse the content message
let content_message = arguments
.get("content")
.and_then(Value::as_str)
.ok_or_else(|| {
InferError::ToolError(
"No `content` found in generated text".to_string(),
)
})?
.to_string();
(None, Some(content_message))
}
_ => {
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name,
arguments,
},
}];
(Some(tool_calls), None)
match crate::chat::parse_output(&generation.generated_text)? {
ChatChoice::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
assert!(!using_tools);
let (headers_final, input_length_final, Json(generation)) = generate_internal(
Extension(infer),
compute_type,
Json(generate_request),
span,
)
.await?;
headers = headers_final;
input_length = input_length_final;
(None, Some(generation.generated_text))
}
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
}
} else {
(None, Some(generation.generated_text))
@ -1727,7 +1522,7 @@ pub async fn run(
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new().with_progress(false);
let mut builder = ApiBuilder::from_env().with_progress(false);
if let Some(token) = authorization_token {
builder = builder.with_token(Some(token));
}
@ -1781,6 +1576,7 @@ pub async fn run(
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
chat_template_filename,
model_info,
) = match api {
Type::None => (
@ -1788,6 +1584,7 @@ pub async fn run(
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
Some(local_path.join("chat_template.json")),
None,
),
Type::Api(api) => {
@ -1801,6 +1598,7 @@ pub async fn run(
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let chat_template_filename = api_repo.get("chat_template.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
@ -1813,10 +1611,12 @@ pub async fn run(
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
chat_template_filename,
model_info,
)
}
Type::Cache(cache) => {
tracing::info!("Cache {cache:?}");
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
@ -1827,23 +1627,41 @@ pub async fn run(
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
repo.get("chat_template.json"),
None,
)
}
};
// if chat_template_filename is present, load the chat template
let chat_template: Option<crate::ChatTemplateVersions> = chat_template_filename
.and_then(|f| std::fs::read_to_string(f).ok())
.and_then(|c| {
let res = serde_json::from_str::<crate::ChatTemplateStandalone>(&c);
if let Err(e) = &res {
tracing::warn!("Could not parse chat template {e:?}");
}
res.ok().map(|t| t.chat_template)
});
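// `chat_template.json` is expected to contain just the template field that
// would otherwise live in `tokenizer_config.json`, e.g. (sketch):
//   { "chat_template": "{{ bos_token }}{% for message in messages %}...{% endfor %}" }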
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
let mut tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
if chat_template.is_some() {
tracing::info!("Using chat template from chat_template.json");
tokenizer_config.chat_template = chat_template;
}
let tokenizer: Result<Tokenizer, WebServerError> = {
use pyo3::prelude::*;
Python::with_gil(|py| -> PyResult<()> {

View File

@ -18,6 +18,7 @@ use std::sync::Arc;
use thiserror::Error;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::warn;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};
@ -694,6 +695,14 @@ fn image_tokens(
"<|vision_start|>{:?}<|vision_end|>",
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
),
Gemma3(_config) => {
// TODO: prefer using the config to determine the number of features
let num_mm_soft_tokens_per_image = 256;
format!(
"\n\n<start_of_image>{}<end_of_image>\n\n",
"<image_soft_token>".repeat(num_mm_soft_tokens_per_image)
)
}
_ => unimplemented!("Image tokens are not supported for this model configuration"),
}
}
@ -721,8 +730,8 @@ fn prepare_input<T: TokenizerTrait>(
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
| Qwen2Vl(_) | Qwen2_5Vl(_)),
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Paligemma(_)
| LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());

View File

@ -39,6 +39,7 @@ install: install-cuda
install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention
uv pip install -e ".[attention,bnb,marlin,moe]"
uv pip install nvidia-nccl-cu12==2.22.3
kernels download .
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

File diff suppressed because it is too large

272
server/kernels.lock Normal file
View File

@ -0,0 +1,272 @@
[
{
"repo_id": "kernels-community/paged-attention",
"sha": "331b7e63a6b592799c8bc992f681bb1ee2c865a2",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-8e0aa39abab82f1d21b661d35e0470a24c3ebbdda38532ded805c18037a1ad1e",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b0c3aef6c4c9aac627975cb1a2bfc46a70390763c8165575b89d1651d007c38a",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-960fbc8998439d779adb47fb2a37cce68c7dc075d8a49893bd487be9ca2d1389",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-9d6d60c411c55aa2f9d7c681c2be96f4262d56c96f592f3d4fb35ce4f4f1e18e",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-98c0a305b2cc9b7be757fab923d9aa406c686dcd0460e462926f87d051ef3d19",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-71e586416213c96ffbdeae0d077ba97bfde5b00005f2746d4cba2320cb53bf87",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-2f559312c54d558b33a4082ffc3fcf923f51da40ced19bfc8920e998ba2b71bf",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-6033b41a0f8a9509887c6171f0b42d9aa738490903b3fd5ea2c52703c5fb8fc3",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-3139f66a53f2bf0c314b4d309893095746bdc9c3914c904fc31adfdf553ed219",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-2173d77e384d8e2881fc38603992c09e8be7bcd9da4cafdd4f2a5ce0ce22caf4",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-7b1aaef81e01ecce83e03c50872910680ff2953f7c6ffd3ff15e8d9497ca9239",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-818b160a88b12b8e871099e40f76aa436ee828e2e060ecc35502dbe34a6ebd3b",
"hash_type": "git_lfs_concat"
}
}
},
{
"repo_id": "kernels-community/moe",
"sha": "605a216f507b9a97b543140dee8937a4622069a8",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-855d92f02be3bfba0758161fa1266159d76c172e7c5d43d30816d22cfba76074",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-e6e780230477bbbc26fc40cc7fcff50298155998af4fc77a026c9f815ec984b1",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-52c1fb337033c4d1d7a279c5cb28aebbc7389976f21dc5803aeb16b2f7aeb94c",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-1fb654e8d02dda2a2382d1fb3a3ca9738d292eea674b30b80030cdcdfb6a0035",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-0cf235f1de85d4ce7490c79aa64220f608f886f313b676d91c331a6a2fd67bbb",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-3def11fee9bf1ea9b1579206fd5f5ecbcaad47ac478e2c3aa7b2c9c7fd5db934",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3a49ee03f675190a79c7c74a45cc403d491eceb63a943f47d52064a11ca6ef6f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-dbf20cb11db7d53e11147ab13641eefaa235f9ac2fde1beaf8f56f850c11bd54",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-8a07232ab316e8eab74747662cb7b86aac03f44ff158f275768fd59390df2525",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-cdd46301af997eeace5e016d8590969981b3a3f8647828d04baa5fa10c696746",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-c865188e9d2c17f3358f3d343fb40340232457572744bf85efd6b20af545d5f3",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-2a8b09f3272ea80491e78a39ff886680471d99f7ba571581809adfe918013898",
"hash_type": "git_lfs_concat"
}
}
},
{
"repo_id": "kernels-community/quantization",
"sha": "95272c71ca71b1ddbacb0105dab54e5d5240bd5c",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-2d0a274cf0117bf7880d6040adafa1b70fe8bff3a00ef2834ed5435a6b525a49",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-116458beac63ea5eeb1e7fba7edc68d160cd8ac28f55b926d79035551aac7d5f",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-cace644c6fb04470384796c18987135cb051dfb90a14e902c51a3786fc07c599",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-104c6961cd3e1a74efdf14ea2172acc6647846852fccafe3698a27a6cf37941d",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-cdc95b41aa91a803f11f8cd53001895c2b69550b5af2fb278d6f124381229d0b",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-d5388469cb6074f196f20b1e1e4805bb3c967a8147b31ca2c0461aa87b50604e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-70c4bb3792c4c3207d4963173d8d0ef3b2bda677151aef140662dd87bfa1b69f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-bcacbb2232f49345f27e07fa821b48a7e3df643c01af37281fcafc74c471f682",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-344d20964f7eb133e5ec6fda976fa5ee62807b739a4361f236aca5ae53beb9ac",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-dfaec226550254fbce1a5c7e2f547e85700958a1a4087e1c873d22e6f71a5ceb",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-0abe6460d0a2202b0086e3663092595e5b93b9a9cbb85c10034180cc9bfebc6e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-68e156f94c3c0c9523773b62eaeced93766e0d9ee67d8191fb9570fb5af30d5b",
"hash_type": "git_lfs_concat"
}
}
},
{
"repo_id": "kernels-community/quantization-eetq",
"sha": "a80ce846d6270ddddeee109523ed947f594f246b",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-e06beb00799b1e656583eb0496f09fc0bf1b26f75e9864a2fe19ebd5b62c3671",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-c128d3ef6558cfedf045c4a713891792708851b7f6f027de835d9083cb3b297d",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-c7e2e14fc114788634b34a4f670f7bf4d27321e5ed40ff446f5a25eef70222c7",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-58dad53cfbf1315af464f9d8ba7be9012089c839d4f06a8d2cf8ce0deaf5949a",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-6519af49c0f689744a7b49497ad2bea1524b69e4095446087d7ab622b898aa30",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-94e0731b58a9ba0e5e2f37b100c8d987c80b5d349008ef625917d020b6c52d25",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-e5b04475538f49d7b4ffded080e4c9c86a658abc12667e3838ebcc410ab1eef4",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-783c02db737a6ec9958b3090f164b87888d3b26e30a4fb6e1cd0c1a635753fab",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-a3d81f82f9cfe9d8a6d46758758b3a1b3055d902f41917b4ef2976373db843d6",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-f1de67e17944a9816f778c72ae73bbbc90d795cb4885c2f9ee5e0b9a3c57583b",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-789b50d767a5121a7e5a52eaf0c8e897bf1787f049ca08faffb220e5053a5f10",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-7c7fe57fea7b9be253085d506f01b2487b2306f22bdffe1de44397fc9f8a3613",
"hash_type": "git_lfs_concat"
}
}
},
{
"repo_id": "kernels-community/rotary",
"sha": "4db658e027ec752840bb3f557ee076413b8db03f",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-907df2035267a65793985bb7f69fb2a975955fb08c2bbc78c58def43d02801da",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b614735ae61ee2c1825a3c823fa0cdd3aa07d0bb3f4106001b9e1a557c0ca9b9",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-f2e98ec72faaebc1cae25f83ccdbb151868b6902fb5a0623e09d700a514c2a7e",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-421214c5a576fac2e0b7998395dccd7f66010f65a6fc647ce06b106ea91105d2",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-9d1c464cf7f391975afa48f2254a639f41582155ad1b50c25bb122418ce8db58",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-82f8012d78304efaa7318f106907630294d10c8b5c9f56923c71df0b03e09f14",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-a3247919dcc392efc7e54725dfbce9ee8a796fe4ee53d113048b313de074d3da",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-a21c9734d15946f4cc967d0555d45d7effc6624990c6889fc49162af744fbbe9",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-01cdda160425b29db0d9bb084874ade4ac081735f9717f272aaefe5bcb379ae1",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-17be5b770418ad47101c49d8945b5aa32af9eb5a840bdffb0514d0e264edd860",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-3cd4b9f63cc903e01325b7e5b204e40fc6600c0685f2e19e6f1fa604a599d82d",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-c569f4a4f9b64792507c58d7cfa31dde1285b52125ef07cc98d9f23636af09ca",
"hash_type": "git_lfs_concat"
}
}
}
]

View File

@ -14,7 +14,7 @@ dependencies = [
"grpcio>=1.67.0",
"grpcio-reflection>=1.67.0",
"grpcio-status>=1.67.0",
"hf-kernels>=0.1.5",
"kernels>=0.2.1",
"hf-transfer>=0.1.8",
"loguru>=0.7.3",
"numpy>=1.26,<3",
@ -36,7 +36,7 @@ dependencies = [
]
[build-system]
requires = ["hf-kernels>=0.1.2", "setuptools"]
requires = ["kernels>=0.1.7", "setuptools"]
build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]

View File

@ -205,7 +205,6 @@ class LoraWeights(AdapterWeights):
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
# import ipdb; ipdb.set_trace()
for layer_id in range(nlayers):
key = (layer_id, layer_type)
if key not in target_to_layer:

View File

@ -38,6 +38,7 @@ def paged_attention(
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
window_size_left: Optional[int] = -1,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -79,12 +80,15 @@ def paged_attention(
sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
window_left=window_size_left,
)
elif ATTENTION == "flashdecoding":
max_q = 1
max_k = max_s
import flash_attn_2_cuda
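# Sliding-window (local) attention: the query attends to at most
# `window_size_left` previous tokens; the right window stays 0 (causal)
# unless windowing is disabled (-1 on both sides).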
window_size_right = -1 if window_size_left == -1 else 0
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
@ -109,8 +113,8 @@ def paged_attention(
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
window_size_left, # Window_left
window_size_right, # Window right
softcap,
False, # return softmax
None, # generator
@ -253,6 +257,7 @@ def attention(
sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
window_left=window_size_left,
)
# If we are using flashdecoding or paged, we always use flash-attn for

View File

@ -52,7 +52,6 @@ def use_prefill_with_paged_kv_state(
page_size: int,
kv_dtype: torch.dtype,
q_dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
@ -95,7 +94,6 @@ def use_prefill_with_paged_kv_state(
kv_data_type=kv_dtype,
q_data_type=q_dtype,
page_size=page_size,
window_left=-1 if window_left is None else window_left,
)
yield
finally:
@ -172,7 +170,6 @@ def use_decode_state(
page_size: int,
kv_cache_dtype: torch.dtype,
q_dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
@ -209,7 +206,6 @@ def use_decode_state(
page_size=page_size,
data_type=kv_cache_dtype,
q_data_type=q_dtype,
window_left=-1 if window_left is None else window_left,
)
yield
finally:

View File

@ -78,6 +78,7 @@ def paged_attention(
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
window_size_left: Optional[int] = -1,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

View File

@ -59,6 +59,7 @@ def paged_attention(
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
window_size_left: Optional[int] = -1,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -82,6 +83,8 @@ def paged_attention(
max_k = max_s
import flash_attn_2_cuda
window_size_right = -1 if window_size_left == -1 else 0
if softcap is None:
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
@ -101,8 +104,8 @@ def paged_attention(
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
window_size_left, # Window_left
window_size_right, # Window right
softcap,
False, # return softmax
None, # generator

View File

@ -106,6 +106,17 @@ try:
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
FlashGemma3ForCausalLM,
Gemma3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.gemma3.processing_gemma3 import (
Gemma3Processor,
)
from text_generation_server.models.custom_modeling.gemma3.configuration_gemma3 import (
Gemma3Config,
Gemma3TextConfig,
)
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
@ -258,6 +269,16 @@ class ModelType(enum.Enum):
"name": "Gemma2",
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
}
GEMMA3 = {
"type": "gemma3",
"name": "Gemma3",
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
}
GEMMA3_TEXT = {
"type": "gemma3_text",
"name": "Gemma3 Text",
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
@ -1094,6 +1115,83 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == GEMMA3_TEXT:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# TODO: once implemented in transformers, use the config class
# and processor class from there.
config_class=Gemma3TextConfig,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == GEMMA3:
if FLASH_ATTENTION:
# TODO: Use VlmCausalLM when image support is added.
return VlmCausalLM(
model_id=model_id,
model_class=Gemma3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# TODO: once implemented in transformers, use the config class
# and processor class from there.
config_class=Gemma3Config,
processor_class=Gemma3Processor,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == COHERE:
if FLASH_ATTENTION:

View File

@ -287,6 +287,7 @@ class FlashGemma2Attention(torch.nn.Module):
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
window_size_left=self.window_size,
)
return self.o_proj(

View File

@ -0,0 +1,902 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from typing import Optional, List, Tuple
import copy
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
#
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
import torch.nn.functional as F
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
)
ATTENTION_TYPE_GLOBAL = "global"
ATTENTION_TYPE_LOCAL = "local_sliding"
class Gemma3FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
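# Gemma-style RMSNorm stores the scale as an offset from 1, so the effective weight is (1 + stored_weight).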
weight = weights.get_tensor(f"{prefix}.weight") + 1
weights.dtype = dtype
new = cls(weight, eps)
new.dtype = dtype
return new
# perform the multiplication in full precision and downcast after
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
)
if isinstance(weight, UnquantizedWeight):
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma3Attention(torch.nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
# TODO: remove this hack to support local sliding window
config = copy.deepcopy(config)
config.rope_scaling = dict(rope_type="default")
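# Sliding-window layers use the dedicated local RoPE base frequency and no long-context scaling.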
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,
)
else:
self.window_size = -1
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = (
config.query_pre_attn_scalar**-0.5
if config.query_pre_attn_scalar is not None
else None
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.softcap = None # config.attn_logit_softcapping
query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
self.head_size * config.num_attention_heads,
self.head_size * config.num_key_value_heads,
self.head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.q_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.enable_gqa = self.num_heads != self.num_key_value_heads
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
key = kv[:, 0]
value = kv[:, 1]
query = query.reshape(-1, self.head_size)
key = key.reshape(-1, self.head_size)
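# Gemma3 applies RMSNorm to the query and key projections before the rotary embeddings (QK-norm).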
query, _ = self.q_norm(query.contiguous())
key, _ = self.k_norm(key.contiguous())
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size)
value = value.view(-1, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
if attention_mask is None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.window_size,
softcap=self.softcap,
)
else:
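# A custom (image-aware, partially bidirectional) mask is required, so fall back to padded
# scaled_dot_product_attention instead of the flash prefill kernel.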
lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
# Split tensors using vectorized split
query_list = torch.split(query, lengths.tolist(), dim=0)
key_list = torch.split(key, lengths.tolist(), dim=0)
value_list = torch.split(value, lengths.tolist(), dim=0)
padded_query = torch.nn.utils.rnn.pad_sequence(
query_list, batch_first=True
)
padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
padded_value = torch.nn.utils.rnn.pad_sequence(
value_list, batch_first=True
)
padded_query = padded_query.transpose(1, 2).contiguous()
padded_key = padded_key.transpose(1, 2).contiguous()
padded_value = padded_value.transpose(1, 2).contiguous()
# Compute attention
attn_output = F.scaled_dot_product_attention(
padded_query,
padded_key,
padded_value,
attn_mask=attention_mask,
scale=self.softmax_scale,
enable_gqa=self.enable_gqa,
)
attn_output = attn_output.transpose(
1, 2
) # [batch_size, seq_len, num_heads, head_dim]
max_seq_len = padded_query.size(2)
seq_range = torch.arange(
max_seq_len, device=padded_query.device
).unsqueeze(0)
lengths_tensor = torch.tensor(
lengths, device=padded_query.device
).unsqueeze(1)
mask = seq_range < lengths_tensor # [batch, max_seq_len]
attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim]
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
window_size_left=self.window_size,
)
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Gemma3MLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_activation
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashGemma3Layer(nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.self_attn = FlashGemma3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
)
self.mlp = Gemma3MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.input_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.pre_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.post_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
)
# faster post attention rms norm
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
normed_attn_res_output = normed_attn_res_output + res
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed, adapter_data)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
class FlashGemma3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGemma3Layer(
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
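# Every `sliding_window_pattern`-th layer uses global attention; all other layers use sliding-window (local) attention.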
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_mask_local: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
residual = None
for i, layer in enumerate(self.layers):
cos, sin = self.layers[i].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
adapter_data,
(
attention_mask
if self.layers[i].self_attn.window_size == -1
else attention_mask_local
),
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashGemma3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
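# Gemma scales the token embeddings by sqrt(hidden_size) before feeding them to the decoder.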
embed_norm = config.hidden_size**0.5
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemma3Model(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load(
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
# self.softcap = config.attn_logit_softcapping
# assert isinstance(self.softcap, float)
self.softcap = None
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
class Gemma3MultimodalInputProjection(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.mm_input_projection_weight = weights.get_tensor(
"multi_modal_projector.mm_input_projection_weight"
)
self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.mm_soft_emb_norm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
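# Average-pool the vision patch grid down to tokens_per_side x tokens_per_side = mm_tokens_per_image soft tokens.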
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
if config.vision_config is not None:
config.vision_config.quantize = config.quantize
self.post_vision_model_layernorm = nn.LayerNorm.load(
prefix="vision_tower.vision_model.post_layernorm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.multimodal_projector = Gemma3MultimodalInputProjection(
prefix="multi_modal_projector",
config=config,
weights=weights,
)
text_config = config.text_config
text_config.speculator = config.speculator
text_config.quantize = config.quantize
self.vision_model = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
else:
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix=prefix,
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def get_attention_mask(
self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
):
device = input_ids.device
min_dtype = torch.finfo(dtype).min
lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
batch_size = len(lengths)
sequence_length = max(lengths)
target_length = sequence_length
# Create the padding mask from the computed lengths.
# pad_mask: [batch, sequence_length] where True indicates valid tokens.
seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
# Build the base causal mask (for non-image tokens):
causal_mask = torch.tril(
torch.ones(
(sequence_length, sequence_length), dtype=torch.bool, device=device
)
)
base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
1
) # [batch, sequence_length, sequence_length]
base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint
image_token_mask = torch.nn.utils.rnn.pad_sequence(
torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
)
bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
1
)
# Combine the causal base mask and the bidirectional mask.
combined_mask = torch.logical_or(
base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
).to(device)
# combined_mask now has shape [batch, 1, sequence_length, sequence_length]
full_attention_mask = torch.zeros(
(batch_size, 1, sequence_length, target_length),
device=device,
dtype=torch.bool,
)
full_attention_mask[:, :, :, :sequence_length] = combined_mask
final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
return final_attention_mask
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
attention_mask: Optional[torch.BoolTensor] = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
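# Encode the images with the vision tower, project them into the text embedding space,
# and scatter them into the input embeddings at the image-token positions.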
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = image_features.view(
-1, image_features.shape[-1]
)
attention_mask = self.get_attention_mask(
input_ids,
max_s,
cu_seqlen_prefill,
inputs_embeds.dtype,
image_token_mask,
)
# Use flash attention for text-only input
# else:
# if cu_seqlen_prefill is not None:
# min_dtype = torch.finfo(inputs_embeds.dtype).min
# lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
# # Determine the maximum sequence length (after padding) from query.
# sequence_length = max(lengths)
# target_length = sequence_length
# # Create the padding mask from the computed lengths.
# # pad_mask: [batch, sequence_length] where True indicates valid tokens.
# seq_range = torch.arange(
# sequence_length, device=input_ids.device
# ).unsqueeze(0)
# lengths_tensor = torch.tensor(
# lengths, device=input_ids.device
# ).unsqueeze(1)
# pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
# # Build the base causal mask (for non-image tokens):
# causal_mask = torch.tril(
# torch.ones(
# (sequence_length, sequence_length),
# dtype=torch.bool,
# device=input_ids.device,
# )
# )
# base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
# 1
# ) # [batch, sequence_length, sequence_length]
# base_mask = base_mask & causal_mask.unsqueeze(0)
# attention_mask = base_mask.unsqueeze(
# 1
# ) # [batch, 1, sequence_length, sequence_length]
# full_attention_mask = torch.zeros(
# (len(lengths), 1, sequence_length, target_length),
# device=input_ids.device,
# dtype=torch.bool,
# )
# full_attention_mask[:, :, :, :sequence_length] = attention_mask
# attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(
# input_ids.device
# )
if attention_mask is not None:
min_dtype = torch.finfo(inputs_embeds.dtype).min
# prefill may be larger than sliding window
effective_seq_len = max(
position_ids.shape[0], self.config.text_config.sliding_window
)
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool),
diagonal=-self.config.text_config.sliding_window,
)
attention_mask_local = torch.where(
sliding_window_mask, min_dtype, attention_mask
)
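# Keep only the most recent `effective_seq_len` key positions for the sliding-window (local) layers.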
offset = max(0, position_ids.shape[0] - effective_seq_len)
attention_mask_local = attention_mask_local[
:, :, :, offset : offset + effective_seq_len
]
else:
attention_mask_local = None
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
attention_mask=attention_mask,
attention_mask_local=attention_mask_local,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
# pad logit with 1 zero logit for the image token
if pixel_values is not None:
logits = torch.cat(
[logits, torch.zeros(logits.size(0), 1, device=logits.device)], dim=1
)
if speculative_logits is not None:
speculative_logits = torch.cat(
[
speculative_logits,
torch.zeros(
speculative_logits.size(0),
1,
device=speculative_logits.device,
),
],
dim=1,
)
return logits, speculative_logits

View File

@ -242,6 +242,7 @@ class MistralAttention(torch.nn.Module):
seqlen,
max_s,
kv_scales=self.kv_scales,
window_size_left=self.max_past,
)
return self.o_proj(

View File

@ -290,6 +290,7 @@ class MixtralAttention(torch.nn.Module):
seqlen,
max_s,
kv_scales=self.kv_scales,
window_size_left=self.max_past,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
weights,
):
super().__init__()
self.max_past = (
self.window_size = (
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
window_size_left=self.window_size,
)
# Decode
else:
@ -185,6 +185,7 @@ class Qwen2Attention(torch.nn.Module):
seqlen,
max_s,
kv_scales=self.kv_scales,
window_size_left=self.window_size,
)
return self.o_proj(
@ -405,10 +406,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
self.window_size = config.sliding_window
self.window_size_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
if self.max_past is not None
if self.window_size is not None
else None
)
@ -430,10 +431,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
elif self.window_size is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.window_size_tensor)
inputs_embeds = self.embed_tokens(input_ids)

View File

@ -291,6 +291,7 @@ class Starcoder2Attention(torch.nn.Module):
seqlen,
max_s,
kv_scales=self.kv_scales,
window_size_left=self.max_past,
)
return self.o_proj(

View File

@ -0,0 +1,313 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import SiglipVisionConfig
logger = logging.get_logger(__name__)
class Gemma3TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma3-4B.
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 262144):
Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Gemma3Model`]
hidden_size (`int`, *optional*, defaults to 2304):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 9216):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 26):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 4):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, most layers use sliding window
attention (every `sliding_window_pattern`-th layer uses global attention). This is the size of the sliding window.
query_pre_attn_scalar (`float`, *optional*):
The scaling factor used on the attention scores.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings used for global attention.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Period of the attention-layer pattern: every `sliding_window_pattern`-th layer uses global attention, while the remaining layers use sliding-window attention.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to
`"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
activation function.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
final_logit_softcapping (`bool`, *optional*, defaults to `True`):
Whether to apply logit softcapping or not.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
Scaling factor when applying tanh soft-capping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
The cache type to be used with `generate`.
```python
>>> from transformers import Gemma3Model, Gemma3TextConfig
>>> # Initializing a Gemma3 gemma3-4b style configuration
>>> configuration = Gemma3TextConfig()
>>> # Initializing a model from the gemma3-4b style configuration
>>> model = Gemma3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3_text"
def __init__(
self,
vocab_size: int = 262_144,
hidden_size: int = 2304,
intermediate_size: int = 9216,
num_hidden_layers: int = 26,
num_attention_heads: int = 8,
num_key_value_heads: int = 4,
head_dim: int = 256,
sliding_window: int = 4096,
query_pre_attn_scalar: Optional[float] = 256,
rope_theta: float = 1_000_000.0,
rope_scaling=None,
rope_local_base_freq: float = 10_000.0,
sliding_window_pattern: int = 6,
rms_norm_eps: float = 1e-6,
hidden_activation: str = "gelu_pytorch_tanh",
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
tie_word_embeddings: bool = True,
max_position_embeddings: int = 131_072,
initializer_range: float = 0.02,
attention_bias: bool = False,
attention_dropout: float = 0.0,
use_cache: bool = True,
final_logit_softcapping=None,
attn_logit_softcapping=None,
cache_implementation: str = "hybrid",
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_local_base_freq = rope_local_base_freq
# For configuring HybridCache to work with 5:1 attention pattern
self.sliding_window_pattern = sliding_window_pattern
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
rope_config_validation(self)
class Gemma3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
mm_tokens_per_image (`int`, *optional*, defaults to 256):
The number of tokens per image embedding.
boi_token_index (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 256000):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 262144):
The image token index to encode the image prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration
>>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3"
sub_configs = {
"text_config": Gemma3TextConfig,
"vision_config": SiglipVisionConfig,
}
def __init__(
self,
text_config: Optional[Gemma3TextConfig] = None,
vision_config: Optional[SiglipVisionConfig] = None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
image_token_index: int = 262_144,
initializer_range: float = 0.02,
**kwargs,
):
if text_config is None:
text_config = Gemma3TextConfig()
logger.info(
"text_config is None, using default Gemma3TextConfig vision config."
)
elif isinstance(text_config, dict):
text_config = Gemma3TextConfig(**text_config)
if isinstance(vision_config, dict):
vision_config = SiglipVisionConfig(**vision_config)
else:
vision_config = SiglipVisionConfig()
logger.info(
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
"to text tasks."
)
self.text_config = text_config
self.vision_config = vision_config
self.mm_tokens_per_image = mm_tokens_per_image
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
self.initializer_range = initializer_range
super().__init__(**kwargs)
__all__ = ["Gemma3Config", "Gemma3TextConfig"]

View File

@ -0,0 +1,463 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Gemma3."""
import itertools
import math
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from transformers.utils import (
TensorType,
filter_out_non_signature_kwargs,
is_vision_available,
logging,
)
from .utils import make_nested_list_of_images
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class Gemma3ImageProcessor(BaseImageProcessor):
r"""
Constructs a Gemma3 (SigLIP-style) image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
do_pan_and_scan (`bool`, *optional*):
Whether to apply `pan_and_scan` to images.
"""
model_input_names = ["pixel_values", "num_crops"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = False,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
do_pan_and_scan: bool = None,
pan_and_scan_min_crop_size: int = None,
pan_and_scan_max_num_crops: int = None,
pan_and_scan_min_ratio_to_activate: float = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.do_pan_and_scan = do_pan_and_scan
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
def pan_and_scan(
self,
image: np.ndarray,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pan and Scan an image: when the aspect ratio is sufficiently elongated, split the image into several non-overlapping crops so detail is preserved at the model's fixed input resolution. Returns the list of crops, or an empty list when pan-and-scan is not applied.
Args:
image (`np.ndarray`):
Image to split into crops.
pan_and_scan_min_crop_size (`int`):
Minimum size (in pixels) a crop may have; pan-and-scan is skipped if the crops would be smaller.
pan_and_scan_max_num_crops (`int`):
Maximum number of crops to produce per image.
pan_and_scan_min_ratio_to_activate (`float`):
Minimum aspect ratio (long side over short side) required to activate pan-and-scan.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
height, width = get_image_size(image)
# Square or landscape image.
if width >= height:
# Only apply PaS if the image is sufficiently exaggerated
if width / height < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_w = int(
math.floor(width / height + 0.5)
) # Half round up rounding.
num_crops_w = min(
int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w
)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
# Portrait image.
else:
# Only apply PaS if the image is sufficiently exaggerated
if height / width < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_h = int(math.floor(height / width + 0.5))
num_crops_h = min(
int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h
)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(width / num_crops_w))
crop_size_h = int(math.ceil(height / num_crops_h))
# Don't apply PaS if crop size is too small.
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return []
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
if input_data_format == ChannelDimension.LAST:
image_crops = [
image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(
crop_positions_h, crop_positions_w
)
]
else:
image_crops = [
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(
crop_positions_h, crop_positions_w
)
]
return image_crops
def _process_images_for_pas(
self,
images: List[np.ndarray],
do_pan_and_scan: bool,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
pas_images_list = []
num_crops = []
for image in images:
pas_images = self.pan_and_scan(
image=image,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
pas_images_list.extend([image] + pas_images)
num_crops.append(len(pas_images))
return pas_images_list, num_crops
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: bool = True,
do_pan_and_scan: bool = None,
pan_and_scan_min_crop_size: int = None,
pan_and_scan_max_num_crops: int = None,
pan_and_scan_min_ratio_to_activate: float = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
do_pan_and_scan (`bool`, *optional*, defaults to `self.do_pan_and_scan`):
Whether to apply `pan_and_scan` to images.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
do_pan_and_scan = (
do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
)
pan_and_scan_min_crop_size = (
pan_and_scan_min_crop_size
if pan_and_scan_min_crop_size is not None
else self.pan_and_scan_min_crop_size
)
pan_and_scan_max_num_crops = (
pan_and_scan_max_num_crops
if pan_and_scan_max_num_crops is not None
else self.pan_and_scan_max_num_crops
)
pan_and_scan_min_ratio_to_activate = (
pan_and_scan_min_ratio_to_activate
if pan_and_scan_min_ratio_to_activate is not None
else self.pan_and_scan_min_ratio_to_activate
)
images_list = make_nested_list_of_images(images)
if not valid_images(images_list[0]):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images_list = [
[convert_to_rgb(image) for image in images] for images in images_list
]
# All transformations expect numpy arrays.
images_list = [
[to_numpy_array(image) for image in images] for images in images_list
]
if do_rescale and is_scaled_image(images_list[0][0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])
if do_pan_and_scan:
images_list_and_num_crops = [
self._process_images_for_pas(
images=images,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
for images in images_list
]
images_list = [images for images, _ in images_list_and_num_crops]
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
else:
num_crops = [[0] * len(images) for images in images_list]
if do_resize:
height, width = size["height"], size["width"]
images_list = [
[
resize(
image=image,
size=(height, width),
resample=resample,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
if do_rescale:
images_list = [
[
self.rescale(
image=image,
scale=rescale_factor,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
if do_normalize:
images_list = [
[
self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
images = [
to_channel_dimension_format(
image, data_format, input_channel_dim=input_data_format
)
for images in images_list
for image in images
]
data = {"pixel_values": images, "num_crops": num_crops}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma3ImageProcessor"]

View File

@ -0,0 +1,204 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Optional, Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import to_py_obj
from text_generation_server.models.custom_modeling.gemma3.image_processing_gemma3 import (
Gemma3ImageProcessor,
)
from transformers.image_utils import PILImageResampling
from .utils import make_nested_list_of_images
class Gemma3ImagesKwargs(ImagesKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
do_convert_rgb: Optional[bool]
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {
"do_pan_and_scan": False,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
},
}
class Gemma3Processor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
# image_processor_class = "Gemma3ImageProcessor"
image_processor_class = "AutoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor,
tokenizer,
chat_template=None,
num_mm_soft_tokens_per_image: int = 256,
**kwargs,
):
# Force the Gemma3 defaults regardless of the constructor arguments:
# 256 soft tokens per image and no chat template.
num_mm_soft_tokens_per_image = 256
chat_template = None
image_processor = Gemma3ImageProcessor(
image_mean=(127.5,) * 3,
image_std=(127.5,) * 3,
size={"height": 896, "width": 896},
do_rescale=False,
resample=PILImageResampling.BILINEAR,
)
self.image_token_id = tokenizer.image_token_id
image_tokens_expanded = "".join(
[tokenizer.image_token] * num_mm_soft_tokens_per_image
)
self.full_image_sequence = (
f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
)
self.image_processor = image_processor
self.tokenizer = tokenizer
self.chat_template = chat_template
# super().__init__(
# image_processor=image_processor,
# tokenizer=tokenizer,
# chat_template=chat_template,
# **kwargs,
# )
def __call__(
self,
images: ImageInput = None,
text: Union[
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
] = None,
videos=None,
audio=None,
**kwargs: Unpack[Gemma3ProcessorKwargs],
) -> BatchFeature:
if text is None and images is None:
raise ValueError("Provide at least one of `text` or `images`.")
output_kwargs = self._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)
image_inputs = {}
if images is not None:
batched_images = make_nested_list_of_images(images)
image_inputs = self.image_processor(
batched_images, **output_kwargs["images_kwargs"]
)
# Create empty text to be replaced with placeholders
if not text:
text = [
" ".join(["<image>"] * len(images)) for images in batched_images
]
if len(batched_images) != len(text):
raise ValueError(
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
)
# Replace image tokens by the full expanded sequence
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
for prompt, images, num_crops in zip(text, batched_images, batch_num_crops):
image_indexes = [m.start() for m in re.finditer("<image>", prompt)]
if len(images) != len(image_indexes):
raise ValueError(
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
)
# Insert additional image tokens for Pan-and-Scan crops
for num, idx in reversed(list(zip(num_crops, image_indexes))):
if num:
formatted_image_text = (
"Here is the original image <image> and here are some crops to help you see better "
+ " ".join(["<image>"] * num)
)
prompt = (
prompt[:idx]
+ formatted_image_text
+ prompt[idx + len("<image>") :]
)
# Expand placeholder image tokens to the full image token sequence
text = [
prompt.replace("<image>", self.full_image_sequence) for prompt in text
]
text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_input, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["Gemma3Processor"]

View File

@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from transformers.image_utils import ImageInput, is_valid_image, is_pil_image
def is_valid_list_of_images(images: List):
return images and all(is_valid_image(image) for image in images)
def make_nested_list_of_images(
images: Union[List[ImageInput], ImageInput],
) -> ImageInput:
"""
Ensure that the output is a nested list of images.
Args:
images (`Union[List[ImageInput], ImageInput]`):
The input image or images.
Returns:
list: A list of lists of images, or a list of 4D arrays of images.
"""
# If it's a list of batches, it's already in the right format
if (
isinstance(images, (list, tuple))
and all(isinstance(images_i, (list, tuple)) for images_i in images)
and all(is_valid_list_of_images(images_i) for images_i in images)
):
return images
# If it's a list of images, it's a single batch, so convert it to a list of lists
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == 3:
return [images]
if images[0].ndim == 4:
return [list(image) for image in images]
# If it's a single image, convert it to a list of lists
if is_valid_image(images):
if is_pil_image(images) or images.ndim == 3:
return [[images]]
if images.ndim == 4:
return [list(images)]
raise ValueError(
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
)
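A quick sketch of the three accepted input shapes, using PIL images (the ndarray branches behave analogously). The import path is an assumption based on the `from .utils import ...` statement in the processor above.
from PIL import Image
from text_generation_server.models.custom_modeling.gemma3.utils import (
    make_nested_list_of_images,
)

img = Image.new("RGB", (32, 32))

assert make_nested_list_of_images(img) == [[img]]              # single image -> one batch of one
assert make_nested_list_of_images([img, img]) == [[img, img]]  # flat list -> one batch
assert make_nested_list_of_images([[img], [img, img]]) == [[img], [img, img]]  # already nested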

View File

@ -633,7 +633,7 @@ class Qwen2_5VisionModel(nn.Module):
config=config,
weights=weights,
)
# import ipdb; ipdb.set_trace()
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels

View File

@ -542,6 +542,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -23,6 +23,13 @@ def load_text_model(prefix, config, weights, name=None):
)
return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "gemma3" or config.model_type == "gemma3_text":
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
FlashGemma3ForCausalLM,
)
return FlashGemma3ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
@ -42,13 +49,21 @@ def load_vision_model(prefix, config, weights):
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
if config.model_type == "siglip_vision_model":
if (
config.model_type == "siglip_vision_model"
or config.model_type == "gemma3_vision"
):
from text_generation_server.models.custom_modeling.siglip import (
SiglipVisionTransformer,
)
# TODO: ensure that using the prefix doesn't break any existing models
# that rely on the old prefix (update the old models if necessary)
return SiglipVisionTransformer(
prefix="vision_tower.vision_model", config=config, weights=weights
# prefix="vision_model.vision_model", config=config, weights=weights
prefix=f"{prefix}.vision_model",
config=config,
weights=weights,
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@ -83,24 +83,11 @@ from text_generation_server.models.metadata_kernels import (
tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
def small_power_of_2(n: int):
return 1 << ((n - 1).bit_length() - 1)
def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
SLIDING_WINDOW = sliding_window
def get_sliding_windows() -> int:
global SLIDING_WINDOW
return SLIDING_WINDOW
def init_cpu_threads_env(rank_id: int, world_size: int):
import importlib.util
@ -1002,10 +989,8 @@ class FlashCausalLMBatch(Batch):
self.slot_indices,
)
sliding_window = get_sliding_windows()
position_ids = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_cu_outlens = [0]
@ -1064,14 +1049,6 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
@ -1085,9 +1062,6 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX:
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
@ -1151,24 +1125,18 @@ class FlashCausalLMBatch(Batch):
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
if not has_triton():
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
self.prefill_cache_indices = None
if all_prefill_logprobs:
prefill_head_indices = None
@ -1306,9 +1274,7 @@ class FlashCausalLM(Model):
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
if getattr(config, "sliding_window", None) is None:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
@ -2500,7 +2466,6 @@ class FlashCausalLM(Model):
page_size=BLOCK_SIZE,
kv_dtype=self.kv_cache_dtype,
q_dtype=self.dtype,
window_left=self.sliding_window,
)
else:
assert input_lengths_tensor is not None
@ -2514,5 +2479,4 @@ class FlashCausalLM(Model):
page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype,
q_dtype=self.dtype,
window_left=self.sliding_window,
)

View File

@ -110,7 +110,7 @@ class Model(ABC):
requires_padding=self.requires_padding,
dtype=str(self.dtype),
device_type=self.device.type,
window_size=self.sliding_window,
window_size=None,  # Setting this parameter to None disables the sliding-window block logic.
speculate=self.speculate,
support_chunking=self.support_chunking,
use_prefix_caching=PREFIX_CACHING,

View File

@ -128,6 +128,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "gemma3":
# TODO: get the correct number of features by reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -244,6 +250,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
else:
images.append([image])
else:

View File

@ -18,14 +18,10 @@ def get_cuda_free_memory(device, memory_fraction):
def get_xpu_free_memory(device, memory_fraction):
total_memory = torch.xpu.get_device_properties(device).total_memory
device_id = device.index
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
total_free_memory, total_xpu_memory = torch.xpu.mem_get_info(device)
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "0.9"))
free_memory = max(
0,
int(
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
),
0, int(total_free_memory - (1 - memory_fraction) * total_xpu_memory)
)
return free_memory
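A worked example of the new formula, assuming a 64 GiB device with 40 GiB currently free and the default XPU_MEMORY_FRACTION of 0.9:
GiB = 1024**3
total_xpu_memory = 64 * GiB   # second value returned by torch.xpu.mem_get_info(device)
total_free_memory = 40 * GiB  # first value returned by torch.xpu.mem_get_info(device)
memory_fraction = 0.9         # XPU_MEMORY_FRACTION default

free_memory = max(0, int(total_free_memory - (1 - memory_fraction) * total_xpu_memory))
print(free_memory / GiB)  # 33.6: 10% of the card (6.4 GiB) is always held back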

View File

@ -1,7 +1,7 @@
import importlib
from loguru import logger
from hf_kernels import load_kernel as hf_load_kernel
from kernels import load_kernel as hf_load_kernel
from text_generation_server.utils.log import log_once

View File

@ -79,6 +79,8 @@ def _get_quantizer_config(model_id, revision):
modules_to_not_convert = data["quantization_config"].get(
"modules_to_not_convert", []
)
if modules_to_not_convert is None:
modules_to_not_convert = []
except Exception:
filename = "quantize_config.json"
try:
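The `modules_to_not_convert is None` guard added above covers checkpoints whose config.json sets the field to null explicitly, in which case `.get()` returns None instead of the default `[]`. A small standalone sketch of that failure mode:
# A quantization_config with an explicit null, as found in some checkpoints:
#   "quantization_config": {"modules_to_not_convert": null}
data = {"quantization_config": {"modules_to_not_convert": None}}

modules_to_not_convert = data["quantization_config"].get("modules_to_not_convert", [])
assert modules_to_not_convert is None  # the default [] is NOT used for an explicit null

if modules_to_not_convert is None:
    modules_to_not_convert = []
assert modules_to_not_convert == []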

Some files were not shown because too many files have changed in this diff.