mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
Merge branch 'main' into gaudi_backend_pa
This commit is contained in:
commit
d5b78ba16f
16
.github/workflows/build.yaml
vendored
16
.github/workflows/build.yaml
vendored
@ -124,6 +124,15 @@ jobs:
|
||||
export extra_pytest="--neuron"
|
||||
export target=""
|
||||
;;
|
||||
gaudi)
|
||||
export dockerfile="Dockerfile_gaudi"
|
||||
export label_extension="-gaudi"
|
||||
export docker_volume="/mnt/cache"
|
||||
export docker_devices=""
|
||||
export runs_on="ubuntu-latest"
|
||||
export platform=""
|
||||
export extra_pytest=""
|
||||
export target=""
|
||||
esac
|
||||
echo $dockerfile
|
||||
echo "Dockerfile=${dockerfile}"
|
||||
@ -224,7 +233,12 @@ jobs:
|
||||
- name: Final
|
||||
id: final
|
||||
run: |
|
||||
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
|
||||
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
|
||||
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
|
||||
|
3
.github/workflows/ci_build.yaml
vendored
3
.github/workflows/ci_build.yaml
vendored
@ -21,6 +21,7 @@ on:
|
||||
- "Dockerfile_amd"
|
||||
- "Dockerfile_intel"
|
||||
- "Dockerfile.neuron"
|
||||
- "Dockerfile_gaudi"
|
||||
branches:
|
||||
- "main"
|
||||
workflow_dispatch:
|
||||
@ -38,7 +39,7 @@ jobs:
|
||||
# fail-fast is true by default
|
||||
fail-fast: false
|
||||
matrix:
|
||||
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron"]
|
||||
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
|
||||
uses: ./.github/workflows/build.yaml # calls the one above ^
|
||||
permissions:
|
||||
contents: write
|
||||
|
53
.github/workflows/nix_build.yaml
vendored
Normal file
53
.github/workflows/nix_build.yaml
vendored
Normal file
@ -0,0 +1,53 @@
|
||||
name: "Nix Build Docker image"
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
tags:
|
||||
- 'v*'
|
||||
concurrency:
|
||||
group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build_nix_image:
|
||||
runs-on:
|
||||
group: aws-highmemory-32-plus-priv
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: cachix/install-nix-action@v27
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
- uses: cachix/cachix-action@v14
|
||||
with:
|
||||
name: text-generation-inference
|
||||
# If you chose signing key for write access
|
||||
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
|
||||
env:
|
||||
USER: github_runner
|
||||
- name: Build
|
||||
run: nix build .#dockerImage
|
||||
- name: Initialize Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
install: true
|
||||
buildkitd-config: /tmp/buildkitd.toml
|
||||
- name: Inject slug/short variables
|
||||
uses: rlespinasse/github-slug-action@v4.4.1
|
||||
- name: Login to internal Container Registry
|
||||
# if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.REGISTRY_USERNAME }}
|
||||
password: ${{ secrets.REGISTRY_PASSWORD }}
|
||||
registry: registry.internal.huggingface.tech
|
||||
- name: Push to docker
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
|
||||
else
|
||||
export TAG=nix-${{ github.ref_name }}
|
||||
fi
|
||||
export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
|
||||
nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
|
2
.github/workflows/tests.yaml
vendored
2
.github/workflows/tests.yaml
vendored
@ -46,7 +46,7 @@ jobs:
|
||||
- name: Download locked kernels
|
||||
run: |
|
||||
source ./.venv/bin/activate
|
||||
hf-kernels download server
|
||||
kernels download server
|
||||
- name: Run server tests
|
||||
run: |
|
||||
source ./.venv/bin/activate
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -28,3 +28,4 @@ server/fbgemmm
|
||||
hl-smi_log*.txt
|
||||
.graph_dumps
|
||||
out
|
||||
hqt_output
|
||||
|
16
Cargo.lock
generated
16
Cargo.lock
generated
@ -4617,7 +4617,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"clap 4.5.30",
|
||||
@ -4638,7 +4638,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap 4.5.30",
|
||||
@ -4658,7 +4658,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
@ -4676,7 +4676,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"clap 4.5.30",
|
||||
"ctrlc",
|
||||
@ -4697,7 +4697,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-stream",
|
||||
@ -4749,7 +4749,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router-llamacpp"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bindgen 0.71.1",
|
||||
@ -4767,7 +4767,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router-v2"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
@ -4816,7 +4816,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router-v3"
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
@ -21,7 +21,7 @@ default-members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "3.1.2-dev0"
|
||||
version = "3.2.1-dev0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
@ -29,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
[workspace.dependencies]
|
||||
base64 = "0.22.0"
|
||||
tokenizers = { version = "0.20.0", features = ["http"] }
|
||||
hf-hub = { version = "0.4.1", features = ["tokio"] }
|
||||
hf-hub = { version = "0.4.2", features = ["tokio"] }
|
||||
metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
minijinja = { version = "2.2.0", features = ["json"] }
|
||||
|
@ -183,12 +183,12 @@ COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
ENV HF_KERNELS_CACHE=/kernels
|
||||
RUN cd server && \
|
||||
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --no-install-project --active && \
|
||||
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --no-install-project --active && \
|
||||
make gen-server-raw && \
|
||||
hf-kernels download .
|
||||
kernels download .
|
||||
|
||||
RUN cd server && \
|
||||
uv sync --frozen --extra gen --extra attention --extra bnb --extra accelerate --extra compressed-tensors --extra marlin --extra moe --extra quantize --extra peft --extra outlines --active --python=${PYTHON_VERSION} && \
|
||||
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --active --python=${PYTHON_VERSION} && \
|
||||
uv pip install nvidia-nccl-cu12==2.25.1 && \
|
||||
pwd && \
|
||||
text-generation-server --help
|
||||
|
@ -5,7 +5,7 @@ RUN mkdir -p /tgi
|
||||
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
|
||||
FROM alpine AS optimum-neuron
|
||||
RUN mkdir -p /optimum-neuron
|
||||
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.0.28.tar.gz /optimum-neuron/sources.tar.gz
|
||||
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
|
||||
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
|
||||
|
||||
# Build cargo components (adapted from TGI original Dockerfile)
|
||||
@ -108,10 +108,10 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
|
||||
# Install neuronx packages
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
aws-neuronx-dkms=2.18.20.0 \
|
||||
aws-neuronx-collectives=2.22.33.0-d2128d1aa \
|
||||
aws-neuronx-runtime-lib=2.22.19.0-5856c0b42 \
|
||||
aws-neuronx-tools=2.19.0.0 \
|
||||
aws-neuronx-dkms=2.19.64.0 \
|
||||
aws-neuronx-collectives=2.23.135.0-3e70920f2 \
|
||||
aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
|
||||
aws-neuronx-tools=2.20.204.0 \
|
||||
libxml2 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get clean
|
||||
@ -120,16 +120,16 @@ ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
|
||||
|
||||
# Install manually torch CPU version to avoid pulling CUDA
|
||||
RUN pip3 install \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torch==2.5.1 \
|
||||
torchvision==0.20.1 \
|
||||
--index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
RUN pip3 install \
|
||||
neuronx-cc==2.15.143.0 \
|
||||
torch-neuronx==2.1.2.2.3.2 \
|
||||
transformers-neuronx==0.12.313 \
|
||||
neuronx-distributed==0.9.0 \
|
||||
libneuronxla==2.0.5347.0 \
|
||||
neuronx-cc==2.16.372.0 \
|
||||
torch-neuronx==2.5.1.2.4.0 \
|
||||
transformers-neuronx==0.13.322 \
|
||||
neuronx-distributed==0.10.1 \
|
||||
libneuronxla==2.1.681.0 \
|
||||
--extra-index-url=https://pip.repos.neuron.amazonaws.com
|
||||
|
||||
# Install HuggingFace packages
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Those arguments are required to build the image
|
||||
ARG HABANA_VERSION
|
||||
ARG PYTORCH_VERSION
|
||||
ARG HABANA_VERSION=1.20.0
|
||||
ARG PYTORCH_VERSION=2.6.0
|
||||
|
||||
# Rust builder
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.85.0 AS chef
|
||||
@ -92,7 +92,6 @@ RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install --no-deps -r requirements.txt && \
|
||||
bash ./dill-0.3.8-patch.sh && \
|
||||
pip install outlines~=0.0.34 && \
|
||||
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
||||
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
||||
pip install . --no-cache-dir
|
||||
|
@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
|
||||
|
||||
# Text Generation Inference base image for Intel
|
||||
|
||||
FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
|
||||
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
|
||||
|
||||
USER root
|
||||
|
||||
@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
|
||||
|
||||
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HF_HOME=/data \
|
||||
@ -96,13 +96,11 @@ ENV HF_HOME=/data \
|
||||
|
||||
|
||||
|
||||
WORKDIR /usr/src
|
||||
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
|
||||
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
|
||||
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
|
||||
RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
|
||||
|
||||
RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
|
||||
WORKDIR /usr/src
|
||||
RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu
|
||||
|
||||
RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
@ -114,15 +112,14 @@ RUN cd server && \
|
||||
pip install -U pip uv && \
|
||||
uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
|
||||
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
|
||||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||
#ENV TORCH_LLM_ALLREDUCE=1
|
||||
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
|
||||
ENV TORCH_LLM_ALLREDUCE=1
|
||||
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
|
||||
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
|
||||
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
|
||||
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
|
||||
|
||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
|
3
Makefile
3
Makefile
@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
|
||||
|
||||
clean:
|
||||
rm -rf target aml
|
||||
|
||||
preview_doc:
|
||||
doc-builder preview text-generation-inference docs/source --not_python_module
|
||||
|
@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
|
||||
volume=$PWD/data
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model
|
||||
```
|
||||
|
||||
And then you can make requests like
|
||||
@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
|
||||
|
||||
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1-rocm --model-id $model` instead of the command above.
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1-rocm --model-id $model` instead of the command above.
|
||||
|
||||
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
||||
```
|
||||
@ -152,7 +152,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
token=<your cli READ token>
|
||||
|
||||
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model
|
||||
```
|
||||
|
||||
### A note on Shared Memory (shm)
|
||||
|
@ -2,8 +2,8 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
|
||||
mkfile_dir := $(dir $(mkfile_path))
|
||||
root_dir := "${mkfile_dir}/../.."
|
||||
|
||||
HABANA_VERSION := 1.19.0
|
||||
PYTORCH_VERSION := 2.5.1
|
||||
HABANA_VERSION := 1.20.0
|
||||
PYTORCH_VERSION := 2.6.0
|
||||
|
||||
.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
|
||||
|
||||
|
283
backends/gaudi/examples/docker_commands/docker_commands.md
Normal file
283
backends/gaudi/examples/docker_commands/docker_commands.md
Normal file
@ -0,0 +1,283 @@
|
||||
# Examples of Docker Commands for Gaudi Backend
|
||||
|
||||
This page gives a list of examples of docker run commands for some of the most popular models.
|
||||
|
||||
> **Note:** The parameters are chosen for Gaudi2 hardware to maximize performance on this given hardware, please adjust the parameters based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
|
||||
|
||||
## Default Precision (BF16)
|
||||
|
||||
### Llama3.1-8B on 1 card (BF16)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||
-e BATCH_BUCKET_SIZE=32 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||
```
|
||||
|
||||
### Llama3.1-70B 8 cards (BF16)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||
```
|
||||
|
||||
### Llama2-7B on 1 Card (BF16)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Llama-2-7b-chat-hf
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||
-e BATCH_BUCKET_SIZE=32 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||
```
|
||||
|
||||
### Llama2-70B on 8 cards (BF16)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Llama-2-70b-chat-hf
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||
```
|
||||
|
||||
### Llava-v1.6-Mistral-7B on 1 card (BF16)
|
||||
|
||||
```bash
|
||||
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||
-e BATCH_BUCKET_SIZE=1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||
--max-total-tokens 8192 --max-batch-size 4
|
||||
```
|
||||
|
||||
## FP8 Precision
|
||||
|
||||
Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
|
||||
|
||||
## Llama3.1-8B on 1 Card (FP8)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||
-e BATCH_BUCKET_SIZE=32 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||
```
|
||||
|
||||
## Llama3.1-70B on 8 cards (FP8)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||
```
|
||||
|
||||
## Llama2-7B on 1 Card (FP8)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Llama-2-7b-chat-hf
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=2 \
|
||||
-e BATCH_BUCKET_SIZE=32 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
|
||||
```
|
||||
|
||||
## Llama2-70B on 8 Cards (FP8)
|
||||
|
||||
```bash
|
||||
model=meta-llama/Llama-2-70b-chat-hf
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||
```
|
||||
|
||||
## Llava-v1.6-Mistral-7B on 1 Card (FP8)
|
||||
|
||||
```bash
|
||||
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||
-e BATCH_BUCKET_SIZE=1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||
--max-total-tokens 8192 --max-batch-size 4
|
||||
```
|
||||
|
||||
## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
|
||||
|
||||
```bash
|
||||
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||
-e BATCH_BUCKET_SIZE=1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||
--max-total-tokens 8192 --max-batch-size 4
|
||||
```
|
@ -22,7 +22,7 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
|
||||
hf-transfer = "^0.1.2"
|
||||
sentencepiece = "^0.1.97"
|
||||
peft = "^0.10"
|
||||
optimum-habana = "1.15.0"
|
||||
optimum-habana = "1.16.0"
|
||||
transformers = "4.45.2"
|
||||
numpy = "1.26.4"
|
||||
accelerate = "0.33.0"
|
||||
|
@ -46,7 +46,7 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
optimum-habana==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
@ -87,3 +87,18 @@ wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13"
|
||||
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
nest-asyncio==1.6.0; python_version >= "3.9" and python_version < "3.13"
|
||||
pydantic==2.10.6; python_version >= "3.9" and python_version < "3.13"
|
||||
pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -20,8 +20,9 @@ from text_generation_server.models.causal_lm import CausalLM
|
||||
from text_generation_server.models.bloom import BLOOM
|
||||
from text_generation_server.models.starcoder import StarCoder
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
||||
from text_generation_server.models.custom_modeling.mllama import (
|
||||
MllamaForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
@ -30,9 +31,6 @@ from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import
|
||||
)
|
||||
|
||||
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
||||
# from text_generation_server.models.custom_modeling.mllama import (
|
||||
# MllamaForConditionalGeneration,
|
||||
# )
|
||||
from text_generation_server.utils.adapter import (
|
||||
AdapterParameters,
|
||||
build_layer_weight_lookup,
|
||||
@ -329,6 +327,7 @@ __GLOBALS = locals()
|
||||
for data in ModelType:
|
||||
__GLOBALS[data.name] = data.value["type"]
|
||||
|
||||
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
|
||||
# Disable gradients
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
@ -849,6 +848,8 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
adapt_transformers_to_gaudi()
|
||||
if SDP_ON_BF16 == 1:
|
||||
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
|
||||
if model_type == "gpt_bigcode":
|
||||
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
|
||||
if model_type == "bloom":
|
||||
@ -871,6 +872,17 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "mllama":
|
||||
return VlmCausalLM(
|
||||
model_class=MllamaForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=None,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
|
@ -704,6 +704,9 @@ class CausalLM(Model):
|
||||
htorch.core.hpu_set_env()
|
||||
|
||||
if world_size > 1:
|
||||
os.environ.setdefault(
|
||||
"DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
|
||||
)
|
||||
model = self.get_deepspeed_model(model_id, dtype, revision)
|
||||
model = hq_env.prepare_model_for_quantization(model)
|
||||
else:
|
||||
|
@ -14,25 +14,18 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
unpad_image,
|
||||
)
|
||||
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
@ -40,7 +33,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (height, width).
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
@ -48,7 +41,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (height, width).
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
@ -57,100 +50,53 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
|
||||
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
Calculate the number of patches after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
|
||||
The size of the input image in the format (height, width). ?
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
int: the number of patches
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise TypeError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
|
||||
if not isinstance(image_size, (list, tuple)):
|
||||
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"image_size invalid type {type(image_size)} with value {image_size}"
|
||||
)
|
||||
image_size = image_size.tolist()
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
|
||||
best_resolution = select_best_resolution(image_size, grid_pinpoints)
|
||||
height, width = best_resolution
|
||||
num_patches = 0
|
||||
# consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
|
||||
for i in range(0, height, patch_size):
|
||||
for j in range(0, width, patch_size):
|
||||
num_patches += 1
|
||||
# add the base patch
|
||||
num_patches += 1
|
||||
return num_patches
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
self.text_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
@ -165,126 +111,315 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
token_idx: Optional[torch.Tensor] = None,
|
||||
use_flash_attention: Optional[bool] = True,
|
||||
flash_attention_recompute: Optional[bool] = True,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
if token_idx is not None:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
token_idx=token_idx,
|
||||
use_flash_attention=use_flash_attention,
|
||||
flash_attention_recompute=flash_attention_recompute,
|
||||
)
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
logits = outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
]
|
||||
if pixel_values.dim() == 5:
|
||||
# stacked if input is (batch_size, num_patches, num_channels, height, width)
|
||||
_pixel_values_list = [
|
||||
pix_val[:num_patch]
|
||||
for pix_val, num_patch in zip(pixel_values, image_num_patches)
|
||||
]
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
|
||||
raise ValueError(
|
||||
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
|
||||
)
|
||||
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [
|
||||
image_features.hidden_states[layer_idx]
|
||||
for layer_idx in vision_feature_layer
|
||||
]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
return image_features
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
|
||||
The only differences are:
|
||||
- add new args token_idx
|
||||
- add the process of merging images into inputs_embeds
|
||||
"""
|
||||
token_idx = kwargs.get("token_idx", None)
|
||||
if token_idx is None:
|
||||
return super().prepare_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
use_flash_attention = kwargs.get("use_flash_attention", True)
|
||||
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
labels = kwargs.get("labels", None)
|
||||
if (
|
||||
past_key_values is None
|
||||
and pixel_values is not None
|
||||
and input_ids.shape[1] != 1
|
||||
):
|
||||
vision_feature_select_strategy = kwargs.get(
|
||||
"vision_feature_select_strategy", None
|
||||
)
|
||||
vision_feature_layer = kwargs.get("vision_feature_layer", None)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
# 2. Merge text and images
|
||||
image_features = self.get_image_features(
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx].tolist(),
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(
|
||||
4, 0, 2, 1, 3
|
||||
).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(
|
||||
image_feature, image_sizes[image_idx]
|
||||
)
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
inputs_embeds, image_features, input_ids
|
||||
)
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
elif past_key_values is not None:
|
||||
seq_len = input_ids.shape[1]
|
||||
pad_len = seq_len - token_idx.item()
|
||||
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(
|
||||
first_layer_past_key_value.float().sum(-2) == 0
|
||||
)
|
||||
# Get the target length
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = extended_attention_mask
|
||||
attention_mask[:, -pad_len:] = 0
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
if token_idx is not None:
|
||||
position_ids = (
|
||||
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
)
|
||||
else:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"token_idx": token_idx,
|
||||
"labels": labels,
|
||||
"use_flash_attention": use_flash_attention,
|
||||
"flash_attention_recompute": flash_attention_recompute,
|
||||
}
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
return model_inputs
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -25,6 +25,7 @@ image:
|
||||
--ulimit nofile=100000:100000 \
|
||||
--build-arg VERSION=$(VERSION) \
|
||||
-t text-generation-inference:$(VERSION)-neuron ${root_dir}
|
||||
docker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron
|
||||
|
||||
install_server:
|
||||
make -C ${mkfile_dir}/server install VERSION:=${VERSION}
|
||||
|
@ -10,7 +10,7 @@
|
||||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "3.1.2-dev0"
|
||||
"version": "3.2.1-dev0"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
@ -2148,9 +2148,6 @@
|
||||
},
|
||||
"StreamOptions": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"include_usage"
|
||||
],
|
||||
"properties": {
|
||||
"include_usage": {
|
||||
"type": "boolean",
|
||||
|
@ -52,6 +52,8 @@
|
||||
- sections:
|
||||
- local: backends/neuron
|
||||
title: Neuron
|
||||
- local: backends/gaudi
|
||||
title: Gaudi
|
||||
- local: backends/trtllm
|
||||
title: TensorRT-LLM
|
||||
- local: backends/llamacpp
|
||||
|
317
docs/source/backends/gaudi.mdx
Normal file
317
docs/source/backends/gaudi.mdx
Normal file
@ -0,0 +1,317 @@
|
||||
# Gaudi Backend for Text Generation Inference
|
||||
|
||||
## Overview
|
||||
Text Generation Inference (TGI) has been optimized to run on Gaudi hardware via the Gaudi backend for TGI.
|
||||
|
||||
## Supported Hardware
|
||||
- **Gaudi1**: Available on [AWS EC2 DL1 instances](https://aws.amazon.com/ec2/instance-types/dl1/)
|
||||
- **Gaudi2**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
|
||||
- **Gaudi3**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
|
||||
|
||||
## Tutorial: Getting Started with TGI on Gaudi
|
||||
|
||||
### Basic Usage
|
||||
The easiest way to run TGI on Gaudi is to use the official Docker image:
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
hf_token=YOUR_HF_ACCESS_TOKEN
|
||||
|
||||
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
Once you see the `connected` log, the server is ready to accept requests:
|
||||
> 2024-05-22T19:31:48.302239Z INFO text_generation_router: router/src/main.rs:378: Connected
|
||||
|
||||
You can find your `YOUR_HF_ACCESS_TOKEN` at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). This is necessary to access gated models like llama3.1.
|
||||
|
||||
### Making Your First Request
|
||||
You can send a request from a separate terminal:
|
||||
|
||||
```bash
|
||||
curl 127.0.0.1:8080/generate \
|
||||
-X POST \
|
||||
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
## How-to Guides
|
||||
|
||||
### How to Run Specific Models
|
||||
|
||||
The following models have been validated on Gaudi2:
|
||||
|
||||
| Model | Model ID | BF16 | | FP8 | |
|
||||
|-----------------------|----------------------------------------|-------------|------------|-------------|------------|
|
||||
| | | Single Card | Multi-Card | Single Card | Multi-Card |
|
||||
| Llama2-7B | meta-llama/Llama-2-7b-chat-hf | ✔ | ✔ | ✔ | ✔ |
|
||||
| Llama2-70B | meta-llama/Llama-2-70b-chat-hf | | ✔ | | ✔ |
|
||||
| Llama3-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
|
||||
| Llama3-70B | meta-llama/Meta-Llama-3-70B-Instruct | | ✔ | | ✔ |
|
||||
| Llama3.1-8B | meta-llama/Meta-Llama-3.1-8B-Instruct | ✔ | ✔ | ✔ | ✔ |
|
||||
| Llama3.1-70B | meta-llama/Meta-Llama-3.1-70B-Instruct | | ✔ | | ✔ |
|
||||
| CodeLlama-13B | codellama/CodeLlama-13b-hf | ✔ | ✔ | ✔ | ✔ |
|
||||
| Mixtral-8x7B | mistralai/Mixtral-8x7B-Instruct-v0.1 | ✔ | ✔ | ✔ | ✔ |
|
||||
| Mistral-7B | mistralai/Mistral-7B-Instruct-v0.3 | ✔ | ✔ | ✔ | ✔ |
|
||||
| Falcon-180B | tiiuae/falcon-180B-chat | | ✔ | | ✔ |
|
||||
| Qwen2-72B | Qwen/Qwen2-72B-Instruct | | ✔ | | ✔ |
|
||||
| Starcoder2-3b | bigcode/starcoder2-3b | ✔ | ✔ | ✔ | |
|
||||
| Starcoder2-15b | bigcode/starcoder2-15b | ✔ | ✔ | ✔ | |
|
||||
| Starcoder | bigcode/starcoder | ✔ | ✔ | ✔ | ✔ |
|
||||
| Gemma-7b | google/gemma-7b-it | ✔ | ✔ | ✔ | ✔ |
|
||||
| Llava-v1.6-Mistral-7B | llava-hf/llava-v1.6-mistral-7b-hf | ✔ | ✔ | ✔ | ✔ |
|
||||
|
||||
To run any of these models:
|
||||
|
||||
```bash
|
||||
model=MODEL_ID_THAT_YOU_WANT_TO_RUN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
|
||||
docker run --runtime=habana --cap-add=sys_nice --ipc=host \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
|
||||
--model-id $model
|
||||
<text-generation-inference-launcher-arguments>
|
||||
```
|
||||
|
||||
For the full list of service parameters, refer to the [launcher-arguments page](https://huggingface.co/docs/text-generation-inference/reference/launcher).
|
||||
|
||||
The validated docker commands can be found in the [examples/docker_commands folder](https://github.com/huggingface/text-generation-inference/tree/main/backends/gaudi/examples/docker_commands).
|
||||
|
||||
> Note: `--runtime=habana --cap-add=sys_nice --ipc=host ` is required to enable docker to use the Gaudi hardware (more details [here](https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html)).
|
||||
|
||||
### How to Enable Multi-Card Inference (Sharding)
|
||||
|
||||
TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards.
|
||||
|
||||
For example, on a machine with 8 Gaudi cards, you can run:
|
||||
|
||||
```bash
|
||||
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
tgi-gaudi \
|
||||
--model-id $model --sharded true --num-shard 8
|
||||
```
|
||||
|
||||
<Tip>
|
||||
We recommend always using sharding when running on a multi-card machine.
|
||||
</Tip>
|
||||
|
||||
### How to Use Different Precision Formats
|
||||
|
||||
#### BF16 Precision (Default)
|
||||
By default, all models run with BF16 precision on Gaudi hardware.
|
||||
|
||||
#### FP8 Precision
|
||||
TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html).
|
||||
|
||||
To run FP8 Inference:
|
||||
|
||||
1. Measure statistics using [Optimum Habana measurement script](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8)
|
||||
2. Run the model in TGI with QUANT_CONFIG setting - e.g. `-e QUANT_CONFIG=./quantization_config/maxabs_quant.json`.
|
||||
|
||||
The following commmand example for FP8 inference is based on the assumption that measurement is done via the first step above.
|
||||
|
||||
Example for Llama3.1-70B on 8 cards with FP8 precision:
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-v $PWD/quantization_config:/usr/src/quantization_config \
|
||||
-v $PWD/hqt_output:/usr/src/hqt_output \
|
||||
-e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
|
||||
-e HF_TOKEN=$hf_token \
|
||||
-e MAX_TOTAL_TOKENS=2048 \
|
||||
-e BATCH_BUCKET_SIZE=256 \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=4 \
|
||||
-e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
|
||||
--model-id $model \
|
||||
--sharded true --num-shard 8 \
|
||||
--max-input-tokens 1024 --max-total-tokens 2048 \
|
||||
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
|
||||
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
|
||||
```
|
||||
|
||||
### How to Run Vision-Language Models (VLMs)
|
||||
|
||||
Gaudi supports VLM inference.
|
||||
|
||||
Example for Llava-v1.6-Mistral-7B on 1 card:
|
||||
|
||||
Start the TGI server via the following command:
|
||||
```bash
|
||||
model=llava-hf/llava-v1.6-mistral-7b-hf
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run -p 8080:80 \
|
||||
--runtime=habana \
|
||||
--cap-add=sys_nice \
|
||||
--ipc=host \
|
||||
-v $volume:/data \
|
||||
-e PREFILL_BATCH_BUCKET_SIZE=1 \
|
||||
-e BATCH_BUCKET_SIZE=1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
|
||||
--model-id $model \
|
||||
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
|
||||
--max-total-tokens 8192 --max-batch-size 4
|
||||
```
|
||||
|
||||
You can then send a request to the server via the following command:
|
||||
```bash
|
||||
curl -N 127.0.0.1:8080/generate \
|
||||
-X POST \
|
||||
-d '{"inputs":"What is this a picture of?\n\n","parameters":{"max_new_tokens":32}}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for 2000 input tokens. For example, an image of size 512x512 is represented by 2800 tokens. Thus, `max-input-tokens` must be larger than the number of tokens associated with the image. Otherwise the image may be truncated. We set `BASE_IMAGE_TOKENS=2048` as the default image token value. This is the minimum value of `max-input-tokens`. You can override the environment variable `BASE_IMAGE_TOKENS` to change this value. The warmup will generate graphs with input length from `BASE_IMAGE_TOKENS` to `max-input-tokens`. For Llava-v1.6-Mistral-7B, the value of `max-batch-prefill-tokens` is 16384, which is calcualted as follows: `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`.
|
||||
|
||||
### How to Benchmark Performance
|
||||
|
||||
We recommend using the [inference-benchmarker tool](https://github.com/huggingface/inference-benchmarker) to benchmark performance on Gaudi hardware.
|
||||
|
||||
This benchmark tool simulates user requests and measures the performance of the model on realistic scenarios.
|
||||
|
||||
To run it on the same machine, you can do the following:
|
||||
```bash
|
||||
MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
HF_TOKEN=<your HF READ token>
|
||||
# run a benchmark to evaluate the performance of the model for chat use case
|
||||
# we mount results to the current directory
|
||||
docker run \
|
||||
--rm \
|
||||
-it \
|
||||
--net host \
|
||||
-v $(pwd):/opt/inference-benchmarker/results \
|
||||
-e "HF_TOKEN=$HF_TOKEN" \
|
||||
ghcr.io/huggingface/inference-benchmarker:latest \
|
||||
inference-benchmarker \
|
||||
--tokenizer-name "$MODEL" \
|
||||
--url http://localhost:8080 \
|
||||
--profile chat
|
||||
```
|
||||
|
||||
Please refer to the [inference-benchmarker README](https://github.com/huggingface/inference-benchmarker) for more details.
|
||||
|
||||
### How to Profile Performance
|
||||
|
||||
To collect performance profiling, you need to set the following environment variables:
|
||||
|
||||
| Name | Value(s) | Default | Description |
|
||||
|--------------------| :--------- | :--------------- | :------------------------------------------------------- |
|
||||
| PROF_WAITSTEP | integer | 0 | Control profile wait steps |
|
||||
| PROF_WARMUPSTEP | integer | 0 | Control profile warmup steps |
|
||||
| PROF_STEP | integer | 0 | Enable/disable profile, control profile active steps |
|
||||
| PROF_PATH | string | /tmp/hpu_profile | Define profile folder |
|
||||
| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile |
|
||||
| PROF_RECORD_SHAPES | True/False | False | Control record_shapes option in the profiler |
|
||||
|
||||
To use these environment variables, add them to your docker run command with the -e flag. For example:
|
||||
|
||||
```bash
|
||||
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
-e PROF_WAITSTEP=10 \
|
||||
-e PROF_WARMUPSTEP=10 \
|
||||
-e PROF_STEP=1 \
|
||||
-e PROF_PATH=/tmp/hpu_profile \
|
||||
-e PROF_RANKS=0 \
|
||||
-e PROF_RECORD_SHAPES=True \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-gaudi \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
## Explanation: Understanding TGI on Gaudi
|
||||
|
||||
### The Warmup Process
|
||||
|
||||
To ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode, concatenate).
|
||||
|
||||
Note: Model warmup can take several minutes, especially for FP8 inference. For faster subsequent runs, refer to [Disk Caching Eviction Policy](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#disk-caching-eviction-policy).
|
||||
|
||||
### Understanding Parameter Tuning
|
||||
|
||||
#### Sequence Length Parameters
|
||||
- `--max-input-tokens` is the maximum possible input prompt length. Default value is `4095`.
|
||||
- `--max-total-tokens` is the maximum possible total length of the sequence (input and output). Default value is `4096`.
|
||||
|
||||
#### Batch Size Parameters
|
||||
- For prefill operation, please set `--max-batch-prefill-tokens` as `bs * max-input-tokens`, where `bs` is your expected maximum prefill batch size.
|
||||
- For decode operation, please set `--max-batch-size` as `bs`, where `bs` is your expected maximum decode batch size.
|
||||
- Please note that batch size will be always padded to the nearest multiplication of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE`.
|
||||
|
||||
#### Performance and Memory Parameters
|
||||
- `PAD_SEQUENCE_TO_MULTIPLE_OF` determines sizes of input length buckets. Since warmup creates several graphs for each bucket, it's important to adjust that value proportionally to input sequence length. Otherwise, some out of memory issues can be observed.
|
||||
- `ENABLE_HPU_GRAPH` enables HPU graphs usage, which is crucial for performance results. Recommended value to keep is `true`.
|
||||
|
||||
#### Sequence Length Parameters
|
||||
- `--max-input-tokens`: Maximum possible input prompt length (default: 4095)
|
||||
- `--max-total-tokens`: Maximum possible total sequence length (input + output) (default: 4096)
|
||||
|
||||
#### Batch Size Parameters
|
||||
- `--max-batch-prefill-tokens`: Set as `bs * max-input-tokens` where `bs` is your expected maximum prefill batch size
|
||||
- `--max-batch-size`: Set as `bs` where `bs` is your expected maximum decode batch size
|
||||
- Note: Batch sizes are padded to the nearest multiple of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE`
|
||||
|
||||
## Reference
|
||||
|
||||
This section contains reference information about the Gaudi backend.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following table contains the environment variables that can be used to configure the Gaudi backend:
|
||||
|
||||
| Name | Value(s) | Default | Description | Usage |
|
||||
|-----------------------------| :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |
|
||||
| ENABLE_HPU_GRAPH | True/False | True | Enable hpu graph or not | add -e in docker run command |
|
||||
| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 300/212) | add -e in docker run command |
|
||||
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
|
||||
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
|
||||
| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
|
||||
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
|
||||
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
|
||||
| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which the request are considered overdue and handled with priority. Shorter requests are prioritized otherwise. | add -e in docker run command |
|
||||
| USE_FLASH_ATTENTION | True/False | True | Whether to enable Habana Flash Attention, provided that the model supports it. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa | add -e in docker run command |
|
||||
| FLASH_ATTENTION_RECOMPUTE | True/False | True | Whether to enable Habana Flash Attention in recompute mode on first token generation. | add -e in docker run command |
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md).
|
||||
|
||||
**Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder.
|
||||
|
||||
### Building the Docker Image from Source
|
||||
|
||||
To build the Docker image from source:
|
||||
|
||||
```bash
|
||||
make -C backends/gaudi image
|
||||
```
|
||||
|
||||
This builds the image and saves it as `tgi-gaudi`. You can then run TGI-Gaudi with this image:
|
||||
|
||||
```bash
|
||||
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
volume=$PWD/data
|
||||
hf_token=YOUR_ACCESS_TOKEN
|
||||
|
||||
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
|
||||
-p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
|
||||
tgi-gaudi \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
For more details, see the [README of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/README.md) and the [Makefile of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/Makefile).
|
@ -31,7 +31,7 @@ deployment instructions in the model card:
|
||||
The service is launched simply by running the text-generation-inference container with two sets of parameters:
|
||||
|
||||
```
|
||||
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.1.1-neuron <service_parameters>
|
||||
docker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.2.1-neuron <service_parameters>
|
||||
```
|
||||
|
||||
- system parameters are used to map ports, volumes and devices between the host and the service,
|
||||
|
@ -19,6 +19,6 @@ docker run --gpus all \
|
||||
--shm-size 1g \
|
||||
-e HF_TOKEN=$token \
|
||||
-p 8080:80 \
|
||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 \
|
||||
-v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
@ -19,7 +19,7 @@ bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models.
|
||||
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model --quantize bitsandbytes
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model --quantize bitsandbytes
|
||||
```
|
||||
|
||||
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
|
||||
@ -27,7 +27,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf
|
||||
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model --quantize bitsandbytes-nf4
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model --quantize bitsandbytes-nf4
|
||||
```
|
||||
|
||||
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
|
||||
@ -48,7 +48,7 @@ $$({\hat{W}_{l}}^{*} = argmin_{\hat{W_{l}}} ||W_{l}X-\hat{W}_{l}X||^{2}_{2})$$
|
||||
TGI allows you to both run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using quantization script. You can run a quantized model by simply passing --quantize like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.1.1 --model-id $model --quantize gptq
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.1 --model-id $model --quantize gptq
|
||||
```
|
||||
|
||||
Note that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. However, models quantized using AutoGPTQ or Optimum can still be served by TGI.
|
||||
|
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-rocm \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-rocm \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -1,3 +1,3 @@
|
||||
# Using TGI with Intel Gaudi
|
||||
|
||||
Check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index).
|
||||
You can use TGI on Intel Gaudi using the [TGI gaudi backend](https://huggingface.co/docs/text-generation-inference/backends/gaudi).
|
||||
|
@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm --privileged --cap-add=sys_nice \
|
||||
--device=/dev/dri \
|
||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-intel-xpu \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-intel-xpu \
|
||||
--model-id $model --cuda-graphs 0
|
||||
```
|
||||
|
||||
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||
docker run --rm --privileged --cap-add=sys_nice \
|
||||
--device=/dev/dri \
|
||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1-intel-cpu \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1-intel-cpu \
|
||||
--model-id $model --cuda-graphs 0
|
||||
```
|
||||
|
||||
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:3.1.1 \
|
||||
ghcr.io/huggingface/text-generation-inference:3.2.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
@ -96,7 +96,7 @@ curl 127.0.0.1:8080/generate \
|
||||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||
|
||||
```bash
|
||||
docker run ghcr.io/huggingface/text-generation-inference:3.1.1 --help
|
||||
docker run ghcr.io/huggingface/text-generation-inference:3.2.1 --help
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
@ -163,7 +163,7 @@ hub = {
|
||||
|
||||
# create Hugging Face Model Class
|
||||
huggingface_model = HuggingFaceModel(
|
||||
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.1.1"),
|
||||
image_uri=get_huggingface_llm_image_uri("huggingface",version="3.2.1"),
|
||||
env=hub,
|
||||
role=role,
|
||||
)
|
||||
|
@ -14,6 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio
|
||||
- [Gemma](https://huggingface.co/google/gemma-7b)
|
||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
||||
- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
|
||||
- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)
|
||||
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
|
||||
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
|
||||
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
|
||||
|
@ -978,16 +978,16 @@
|
||||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1740049068,
|
||||
"narHash": "sha256-heYzYOt+TSnRKHIV24s74yEjLkTbBfjNCWHdQEX++eI=",
|
||||
"lastModified": 1741617161,
|
||||
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"rev": "143e8451efa22b120f97e6698508e9a0aed82769",
|
||||
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "hub-rotary",
|
||||
"ref": "kernels-0.2.0",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"type": "github"
|
||||
}
|
||||
|
10
flake.nix
10
flake.nix
@ -5,7 +5,7 @@
|
||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
};
|
||||
nix-filter.url = "github:numtide/nix-filter";
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix/hub-rotary";
|
||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
@ -176,11 +176,15 @@
|
||||
'';
|
||||
};
|
||||
|
||||
dockerImage = pkgs.callPackage nix/docker.nix {
|
||||
# Use plain nixpkgs without overlays for dockerTools. dockerTools
|
||||
# uses a Python package for computing the layers from the transitive
|
||||
# closure. However, this needs a lot of rebuilds due to our overlay.
|
||||
|
||||
dockerImage = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {
|
||||
text-generation-inference = default;
|
||||
};
|
||||
|
||||
dockerImageStreamed = pkgs.callPackage nix/docker.nix {
|
||||
dockerImageStreamed = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {
|
||||
text-generation-inference = default;
|
||||
stream = true;
|
||||
};
|
||||
|
@ -0,0 +1,109 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 16,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -1.3984375,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 1331,
|
||||
"logprob": -1.6953125,
|
||||
"special": false,
|
||||
"text": " people"
|
||||
},
|
||||
{
|
||||
"id": 236764,
|
||||
"logprob": -0.23535156,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 532,
|
||||
"logprob": -0.24316406,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.12109375,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2780,
|
||||
"logprob": -1.1640625,
|
||||
"special": false,
|
||||
"text": " food"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.21386719,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.64453125,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 2094,
|
||||
"logprob": -0.77734375,
|
||||
"special": false,
|
||||
"text": "This"
|
||||
},
|
||||
{
|
||||
"id": 563,
|
||||
"logprob": -0.040283203,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 496,
|
||||
"logprob": -0.03125,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 6290,
|
||||
"logprob": -0.03515625,
|
||||
"special": false,
|
||||
"text": " nice"
|
||||
},
|
||||
{
|
||||
"id": 1977,
|
||||
"logprob": -0.0020751953,
|
||||
"special": false,
|
||||
"text": " place"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.0079956055,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 107,
|
||||
"logprob": -0.9921875,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 106,
|
||||
"logprob": -0.45507812,
|
||||
"special": true,
|
||||
"text": "<end_of_turn>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " the people, and the food.\n\nThis is a nice place.\n"
|
||||
}
|
@ -0,0 +1,613 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 100,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1331,
|
||||
"logprob": -0.34960938,
|
||||
"special": false,
|
||||
"text": " people"
|
||||
},
|
||||
{
|
||||
"id": 8390,
|
||||
"logprob": -0.14746094,
|
||||
"special": false,
|
||||
"text": " died"
|
||||
},
|
||||
{
|
||||
"id": 528,
|
||||
"logprob": -1.2265625,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.47070312,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3640,
|
||||
"logprob": -0.5859375,
|
||||
"special": false,
|
||||
"text": " United"
|
||||
},
|
||||
{
|
||||
"id": 4184,
|
||||
"logprob": -0.0027770996,
|
||||
"special": false,
|
||||
"text": " States"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.34765625,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.0859375,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 818,
|
||||
"logprob": -1.1640625,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 6816,
|
||||
"logprob": -1.890625,
|
||||
"special": false,
|
||||
"text": " generally"
|
||||
},
|
||||
{
|
||||
"id": 10951,
|
||||
"logprob": -0.14648438,
|
||||
"special": false,
|
||||
"text": " accepted"
|
||||
},
|
||||
{
|
||||
"id": 10967,
|
||||
"logprob": -0.90625,
|
||||
"special": false,
|
||||
"text": " estimate"
|
||||
},
|
||||
{
|
||||
"id": 563,
|
||||
"logprob": -0.49414062,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 600,
|
||||
"logprob": -0.65234375,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 236743,
|
||||
"logprob": -1.2109375,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 236825,
|
||||
"logprob": -0.00088119507,
|
||||
"special": false,
|
||||
"text": "6"
|
||||
},
|
||||
{
|
||||
"id": 236832,
|
||||
"logprob": -6.580353e-05,
|
||||
"special": false,
|
||||
"text": "7"
|
||||
},
|
||||
{
|
||||
"id": 236810,
|
||||
"logprob": -5.2690506e-05,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 236764,
|
||||
"logprob": -0.0001745224,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 236771,
|
||||
"logprob": -1.180172e-05,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 236771,
|
||||
"logprob": -1.7881393e-06,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 236771,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 1331,
|
||||
"logprob": -0.44921875,
|
||||
"special": false,
|
||||
"text": " people"
|
||||
},
|
||||
{
|
||||
"id": 8390,
|
||||
"logprob": -0.011474609,
|
||||
"special": false,
|
||||
"text": " died"
|
||||
},
|
||||
{
|
||||
"id": 528,
|
||||
"logprob": -0.084472656,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.00034713745,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3640,
|
||||
"logprob": -0.028564453,
|
||||
"special": false,
|
||||
"text": " United"
|
||||
},
|
||||
{
|
||||
"id": 4184,
|
||||
"logprob": -0.00012207031,
|
||||
"special": false,
|
||||
"text": " States"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -1.15625,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 3153,
|
||||
"logprob": -0.103027344,
|
||||
"special": false,
|
||||
"text": " However"
|
||||
},
|
||||
{
|
||||
"id": 236764,
|
||||
"logprob": -0.009155273,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1070,
|
||||
"logprob": -0.92578125,
|
||||
"special": false,
|
||||
"text": " some"
|
||||
},
|
||||
{
|
||||
"id": 61806,
|
||||
"logprob": -0.91796875,
|
||||
"special": false,
|
||||
"text": " historians"
|
||||
},
|
||||
{
|
||||
"id": 4646,
|
||||
"logprob": -1.3828125,
|
||||
"special": false,
|
||||
"text": " believe"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.65234375,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5396,
|
||||
"logprob": -0.8046875,
|
||||
"special": false,
|
||||
"text": " actual"
|
||||
},
|
||||
{
|
||||
"id": 1548,
|
||||
"logprob": -0.04321289,
|
||||
"special": false,
|
||||
"text": " number"
|
||||
},
|
||||
{
|
||||
"id": 1451,
|
||||
"logprob": -0.66015625,
|
||||
"special": false,
|
||||
"text": " could"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.091308594,
|
||||
"special": false,
|
||||
"text": " be"
|
||||
},
|
||||
{
|
||||
"id": 618,
|
||||
"logprob": -0.57421875,
|
||||
"special": false,
|
||||
"text": " as"
|
||||
},
|
||||
{
|
||||
"id": 1494,
|
||||
"logprob": -0.00036239624,
|
||||
"special": false,
|
||||
"text": " high"
|
||||
},
|
||||
{
|
||||
"id": 618,
|
||||
"logprob": -0.0001335144,
|
||||
"special": false,
|
||||
"text": " as"
|
||||
},
|
||||
{
|
||||
"id": 236743,
|
||||
"logprob": -0.0009689331,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 236770,
|
||||
"logprob": -0.26367188,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 236771,
|
||||
"logprob": -0.17773438,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 3625,
|
||||
"logprob": -0.012084961,
|
||||
"special": false,
|
||||
"text": " million"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.21289062,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.37304688,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 236777,
|
||||
"logprob": -1.078125,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 1006,
|
||||
"logprob": -1.3203125,
|
||||
"special": false,
|
||||
"text": " am"
|
||||
},
|
||||
{
|
||||
"id": 3182,
|
||||
"logprob": -1.078125,
|
||||
"special": false,
|
||||
"text": " looking"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -0.035888672,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 919,
|
||||
"logprob": -1.25,
|
||||
"special": false,
|
||||
"text": " more"
|
||||
},
|
||||
{
|
||||
"id": 1938,
|
||||
"logprob": -1.2421875,
|
||||
"special": false,
|
||||
"text": " information"
|
||||
},
|
||||
{
|
||||
"id": 580,
|
||||
"logprob": -0.7734375,
|
||||
"special": false,
|
||||
"text": " on"
|
||||
},
|
||||
{
|
||||
"id": 672,
|
||||
"logprob": -0.73046875,
|
||||
"special": false,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 59725,
|
||||
"logprob": -0.75,
|
||||
"special": false,
|
||||
"text": " discrepancy"
|
||||
},
|
||||
{
|
||||
"id": 532,
|
||||
"logprob": -0.83984375,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.7109375,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5872,
|
||||
"logprob": -1.2734375,
|
||||
"special": false,
|
||||
"text": " factors"
|
||||
},
|
||||
{
|
||||
"id": 600,
|
||||
"logprob": -0.22851562,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 19263,
|
||||
"logprob": -1.1640625,
|
||||
"special": false,
|
||||
"text": " contributed"
|
||||
},
|
||||
{
|
||||
"id": 531,
|
||||
"logprob": -0.0010757446,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.18945312,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5777,
|
||||
"logprob": -1.2734375,
|
||||
"special": false,
|
||||
"text": " wide"
|
||||
},
|
||||
{
|
||||
"id": 2644,
|
||||
"logprob": -0.01940918,
|
||||
"special": false,
|
||||
"text": " range"
|
||||
},
|
||||
{
|
||||
"id": 529,
|
||||
"logprob": -0.14550781,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 14287,
|
||||
"logprob": -0.032470703,
|
||||
"special": false,
|
||||
"text": " estimates"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.010375977,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.06591797,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 8291,
|
||||
"logprob": -0.8046875,
|
||||
"special": false,
|
||||
"text": "Here"
|
||||
},
|
||||
{
|
||||
"id": 236789,
|
||||
"logprob": -0.23828125,
|
||||
"special": false,
|
||||
"text": "'"
|
||||
},
|
||||
{
|
||||
"id": 236751,
|
||||
"logprob": -1.0728836e-06,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 496,
|
||||
"logprob": -0.17480469,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 25890,
|
||||
"logprob": -0.087402344,
|
||||
"special": false,
|
||||
"text": " breakdown"
|
||||
},
|
||||
{
|
||||
"id": 529,
|
||||
"logprob": -0.0021209717,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.19140625,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5872,
|
||||
"logprob": -1.0078125,
|
||||
"special": false,
|
||||
"text": " factors"
|
||||
},
|
||||
{
|
||||
"id": 20894,
|
||||
"logprob": -0.26367188,
|
||||
"special": false,
|
||||
"text": " contributing"
|
||||
},
|
||||
{
|
||||
"id": 531,
|
||||
"logprob": -9.250641e-05,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.008666992,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5777,
|
||||
"logprob": -0.6171875,
|
||||
"special": false,
|
||||
"text": " wide"
|
||||
},
|
||||
{
|
||||
"id": 2644,
|
||||
"logprob": -0.0023956299,
|
||||
"special": false,
|
||||
"text": " range"
|
||||
},
|
||||
{
|
||||
"id": 529,
|
||||
"logprob": -0.016723633,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 14287,
|
||||
"logprob": -0.011352539,
|
||||
"special": false,
|
||||
"text": " estimates"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -0.30664062,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.21386719,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 236743,
|
||||
"logprob": -0.35351562,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 236770,
|
||||
"logprob": -3.5762787e-07,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 236819,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "9"
|
||||
},
|
||||
{
|
||||
"id": 236770,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 236828,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "8"
|
||||
},
|
||||
{
|
||||
"id": 7745,
|
||||
"logprob": -0.70703125,
|
||||
"special": false,
|
||||
"text": " flu"
|
||||
},
|
||||
{
|
||||
"id": 10248,
|
||||
"logprob": -0.015258789,
|
||||
"special": false,
|
||||
"text": " pandemic"
|
||||
},
|
||||
{
|
||||
"id": 4355,
|
||||
"logprob": -0.83203125,
|
||||
"special": false,
|
||||
"text": " death"
|
||||
},
|
||||
{
|
||||
"id": 25363,
|
||||
"logprob": -7.43866e-05,
|
||||
"special": false,
|
||||
"text": " toll"
|
||||
},
|
||||
{
|
||||
"id": 528,
|
||||
"logprob": -0.08496094,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -6.67572e-06,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3640,
|
||||
"logprob": -0.0059509277,
|
||||
"special": false,
|
||||
"text": " United"
|
||||
},
|
||||
{
|
||||
"id": 4184,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " States"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Okay, let's analyze the image.\n\nThe image is a solid, bright white color. There is nothing else visible within it. \n\nIt's essentially a blank white canvas or a completely white square. \n\nIs there anything specific you'd like me to do with this image, such as describe it further or imagine what it might represent?",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741965894,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 74,
|
||||
"prompt_tokens": 277,
|
||||
"total_tokens": 351
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741965892,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 98,
|
||||
"prompt_tokens": 277,
|
||||
"total_tokens": 375
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741966313,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 67,
|
||||
"prompt_tokens": 277,
|
||||
"total_tokens": 344
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene – a cow enjoying a day at the beach!",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741964480,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 74,
|
||||
"prompt_tokens": 275,
|
||||
"total_tokens": 349
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their reddish-brown color and distinctive white markings. \n\nIf you'd like, you can send me another image and I’ll do my best to identify it!",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741964477,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 75,
|
||||
"prompt_tokens": 279,
|
||||
"total_tokens": 354
|
||||
}
|
||||
}
|
@ -10,7 +10,7 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "{\"format\":\"fahrenheit\",\"location\":\"Brooklyn, NY\"}",
|
||||
"arguments": "{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}",
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
@ -21,7 +21,7 @@
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1741263682,
|
||||
"created": 1741372434,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion",
|
||||
|
@ -10,7 +10,7 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "{\"format\":\"fahrenheit\",\"location\":\"Brooklyn, NY\"}",
|
||||
"arguments": "{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}",
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
@ -21,7 +21,7 @@
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1741263684,
|
||||
"created": 1741372657,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion",
|
||||
|
@ -8,10 +8,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "{\"",
|
||||
"name": null
|
||||
"arguments": "{",
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -22,187 +22,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "function",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": " {\"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "_",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "name",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -221,7 +41,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -232,157 +52,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "get",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "_current",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "_weather",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\",",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -401,7 +71,7 @@
|
||||
"arguments": "location",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -412,7 +82,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -431,7 +101,7 @@
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -442,7 +112,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -461,7 +131,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -472,7 +142,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -488,10 +158,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "Paris",
|
||||
"arguments": "Bro",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -502,7 +172,37 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "oklyn",
|
||||
"name": null
|
||||
},
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -521,7 +221,7 @@
|
||||
"arguments": ",",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -532,7 +232,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -548,10 +248,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": " France",
|
||||
"arguments": " NY",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -562,7 +262,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -581,7 +281,7 @@
|
||||
"arguments": "\",",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -592,7 +292,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -611,7 +311,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -622,7 +322,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -641,7 +341,7 @@
|
||||
"arguments": "format",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -652,7 +352,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -671,7 +371,7 @@
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -682,7 +382,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -701,7 +401,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -712,7 +412,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -728,10 +428,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "c",
|
||||
"arguments": "f",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -742,7 +442,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -758,10 +458,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "elsius",
|
||||
"arguments": "ahrenheit",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -772,7 +472,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -788,10 +488,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\"}}",
|
||||
"arguments": "\"}",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -802,37 +502,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263685,
|
||||
"created": 1741688515,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -5,20 +5,20 @@
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "I am a helpful assistant!",
|
||||
"content": "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1741263686,
|
||||
"created": 1741693957,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 23,
|
||||
"prompt_tokens": 494,
|
||||
"total_tokens": 517
|
||||
"completion_tokens": 12,
|
||||
"prompt_tokens": 53,
|
||||
"total_tokens": 65
|
||||
}
|
||||
}
|
||||
|
@ -12,7 +12,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263687,
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -23,7 +23,7 @@
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " am",
|
||||
"content": "'m",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
@ -32,7 +32,127 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263687,
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " an",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " artificial",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " intelligence",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " model",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " known",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " as",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -52,7 +172,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263687,
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -63,7 +183,7 @@
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " helpful",
|
||||
"content": " large",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
@ -72,7 +192,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263687,
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -83,7 +203,7 @@
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " assistant",
|
||||
"content": " language",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
@ -92,7 +212,187 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263687,
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " model",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " (",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "LL",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "M",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": ")",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " or",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " convers",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "ational",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": " AI",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741694017,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
@ -10,7 +10,7 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "{\"format\":\"fahrenheit\",\"location\":\"Brooklyn, NY\"}",
|
||||
"arguments": "{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}",
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
@ -21,7 +21,7 @@
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1741263680,
|
||||
"created": 1741372335,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion",
|
||||
|
@ -10,10 +10,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "{\"",
|
||||
"name": null
|
||||
"arguments": "{",
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -24,205 +24,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "function",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": " {\"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "_",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "name",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -244,7 +46,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -255,172 +57,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "get",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "_current",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "_weather",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\",",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null,
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -442,7 +79,7 @@
|
||||
"arguments": "location",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -453,7 +90,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -475,7 +112,7 @@
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -486,7 +123,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -508,7 +145,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -519,7 +156,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -541,7 +178,7 @@
|
||||
"arguments": "Bro",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -552,7 +189,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -574,7 +211,7 @@
|
||||
"arguments": "oklyn",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -585,7 +222,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -607,7 +244,7 @@
|
||||
"arguments": ",",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -618,7 +255,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -640,7 +277,7 @@
|
||||
"arguments": " NY",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -651,7 +288,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -673,7 +310,7 @@
|
||||
"arguments": "\",",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -684,7 +321,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -706,7 +343,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -717,7 +354,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -739,7 +376,7 @@
|
||||
"arguments": "format",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -750,7 +387,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -772,7 +409,7 @@
|
||||
"arguments": "\":",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -783,7 +420,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -805,7 +442,7 @@
|
||||
"arguments": " \"",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -816,7 +453,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -838,7 +475,7 @@
|
||||
"arguments": "f",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -849,7 +486,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -871,7 +508,7 @@
|
||||
"arguments": "ahrenheit",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -882,7 +519,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
@ -901,10 +538,10 @@
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "\"}}",
|
||||
"arguments": "\"}",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"id": "0",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
@ -915,40 +552,7 @@
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"service_tier": null,
|
||||
"system_fingerprint": "3.1.2-dev0-native",
|
||||
"usage": null
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"function_call": null,
|
||||
"refusal": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "<|eot_id|>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1741263681,
|
||||
"created": 1741689423,
|
||||
"id": "",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
170
integration-tests/models/test_flash_gemma3.py
Normal file
170
integration-tests/models/test_flash_gemma3.py
Normal file
@ -0,0 +1,170 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_gemma3_handle(launcher):
|
||||
with launcher("google/gemma-3-4b-it", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_gemma3(flash_gemma3_handle):
|
||||
await flash_gemma3_handle.health(300)
|
||||
return flash_gemma3_handle.client
|
||||
|
||||
|
||||
async def test_flash_gemma3(flash_gemma3, response_snapshot):
|
||||
response = await flash_gemma3.generate(
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
seed=42,
|
||||
max_new_tokens=100,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.generated_text
|
||||
== " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
|
||||
)
|
||||
assert response.details.generated_tokens == 100
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
async def test_flash_gemma3_image_cow_dog(flash_gemma3, response_snapshot):
|
||||
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||
response = await flash_gemma3.chat(
|
||||
seed=42,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is the breed of the dog in the image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \n\nBrown Swiss cows are known for their reddish-brown color and distinctive white markings. \n\nIf you'd like, you can send me another image and I’ll do my best to identify it!"
|
||||
)
|
||||
assert response.usage["completion_tokens"] == 75
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
async def test_flash_gemma3_image_cow(flash_gemma3, response_snapshot):
|
||||
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||
response = await flash_gemma3.chat(
|
||||
seed=42,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a quite a humorous and unusual scene – a cow enjoying a day at the beach!"
|
||||
)
|
||||
assert response.usage["completion_tokens"] == 74
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
async def test_exceed_window(flash_gemma3, response_snapshot):
|
||||
response = await flash_gemma3.generate(
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,",
|
||||
seed=42,
|
||||
max_new_tokens=20,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.generated_text
|
||||
== " the people, and the food.\n\nThis is a nice place.\n"
|
||||
)
|
||||
assert response.details.generated_tokens == 16
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
# Helper function to convert a Pillow image to a base64 data URL
|
||||
def image_to_data_url(img: Image.Image, fmt: str) -> str:
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format=fmt)
|
||||
img_data = buffer.getvalue()
|
||||
b64_str = base64.b64encode(img_data).decode("utf-8")
|
||||
mime_type = "image/png" if fmt.upper() == "PNG" else "image/jpeg"
|
||||
return f"data:{mime_type};base64,{b64_str}"
|
||||
|
||||
|
||||
async def test_flash_gemma3_image_base64_rgba(flash_gemma3, response_snapshot):
|
||||
# Create an empty 100x100 PNG image with alpha (transparent background)
|
||||
img = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
|
||||
data_url = image_to_data_url(img, "PNG")
|
||||
response = await flash_gemma3.chat(
|
||||
seed=42,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What do you see in this transparent image?",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
async def test_flash_gemma3_image_base64_rgb_png(flash_gemma3, response_snapshot):
|
||||
# Create an empty 100x100 PNG image without alpha (white background)
|
||||
img = Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
data_url = image_to_data_url(img, "PNG")
|
||||
response = await flash_gemma3.chat(
|
||||
seed=42,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
{"type": "text", "text": "What do you see in this plain image?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
async def test_flash_gemma3_image_base64_rgb_jpg(flash_gemma3, response_snapshot):
|
||||
# Create an empty 100x100 JPEG image (white background)
|
||||
img = Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
data_url = image_to_data_url(img, "JPEG")
|
||||
response = await flash_gemma3.chat(
|
||||
seed=42,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
{"type": "text", "text": "What do you see in this JPEG image?"},
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens=100,
|
||||
)
|
||||
assert response == response_snapshot
|
@ -108,7 +108,7 @@ async def test_flash_llama_grammar_tools_nostream(
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
description=None,
|
||||
name="get_current_weather",
|
||||
arguments='{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||
arguments='{"location":"Brooklyn, NY","format":"fahrenheit"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
@ -142,14 +142,15 @@ async def test_flash_llama_grammar_tools_openai(
|
||||
|
||||
chunks = []
|
||||
tool = ""
|
||||
name = ""
|
||||
for chunk in stream:
|
||||
if chunk.choices[0].delta.tool_calls[0].function.name:
|
||||
name += chunk.choices[0].delta.tool_calls[0].function.name
|
||||
tool += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
chunks.append(chunk)
|
||||
|
||||
assert (
|
||||
tool
|
||||
== '{"function": {"_name": "get_current_weather", "location": "Brooklyn, NY", "format": "fahrenheit"}}<|eot_id|>'
|
||||
)
|
||||
assert name == "get_current_weather"
|
||||
assert tool == '{ "location": "Brooklyn, NY", "format": "fahrenheit"}'
|
||||
assert chunks == response_snapshot
|
||||
|
||||
|
||||
@ -184,7 +185,7 @@ async def test_flash_llama_grammar_tools_auto_nostream(
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
description=None,
|
||||
name="get_current_weather",
|
||||
arguments='{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||
arguments='{"location":"Brooklyn, NY","format":"fahrenheit"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
@ -223,7 +224,7 @@ async def test_flash_llama_grammar_tools_choice_nostream(
|
||||
function=ChatCompletionOutputFunctionDefinition(
|
||||
description=None,
|
||||
name="get_current_weather",
|
||||
arguments='{"format":"fahrenheit","location":"Brooklyn, NY"}',
|
||||
arguments='{"location":"Brooklyn, NY","format":"fahrenheit"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
@ -250,23 +251,24 @@ async def test_flash_llama_grammar_tools_choice_stream(
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Paris, France?",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
tool_calls_generated = ""
|
||||
arguments = ""
|
||||
chunks = []
|
||||
name = ""
|
||||
for chunk in stream:
|
||||
tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
if chunk.choices[0].delta.tool_calls[0].function.name:
|
||||
name += chunk.choices[0].delta.tool_calls[0].function.name
|
||||
arguments += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
assert chunk.choices[0].delta.content is None
|
||||
chunks.append(chunk)
|
||||
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
|
||||
)
|
||||
assert name == "get_current_weather"
|
||||
assert arguments == '{ "location": "Brooklyn, NY", "format": "fahrenheit"}'
|
||||
assert chunks == response_snapshot
|
||||
|
||||
|
||||
@ -277,7 +279,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_nostream(
|
||||
):
|
||||
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
||||
response = client.chat_completion(
|
||||
max_tokens=100,
|
||||
max_tokens=20,
|
||||
seed=24,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
@ -297,9 +299,10 @@ async def test_flash_llama_grammar_tools_insufficient_information_nostream(
|
||||
content_generated = response.choices[0].message.content
|
||||
assert response.choices[0].message.tool_calls is None
|
||||
|
||||
######## FIXME before MERGE ############################
|
||||
# TODO This is different from the streaming case, this is NOT normal.
|
||||
assert content_generated == "I am a helpful assistant!"
|
||||
assert (
|
||||
content_generated
|
||||
== "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@ -310,7 +313,7 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
||||
):
|
||||
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
||||
stream = client.chat_completion(
|
||||
max_tokens=100,
|
||||
max_tokens=20,
|
||||
seed=24,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
@ -334,7 +337,11 @@ async def test_flash_llama_grammar_tools_insufficient_information_stream(
|
||||
chunks.append(chunk)
|
||||
assert chunk.choices[0].delta.tool_calls is None
|
||||
|
||||
assert content_generated == "I am a helpful assistant"
|
||||
######## This is exactly the same as the non streaming case
|
||||
assert (
|
||||
content_generated
|
||||
== "I'm an artificial intelligence model known as a large language model (LLM) or conversational AI"
|
||||
)
|
||||
assert chunks == response_snapshot
|
||||
|
||||
|
||||
@ -345,7 +352,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_auto(
|
||||
):
|
||||
client = InferenceClient(base_url=f"{flash_llama_grammar_tools.base_url}/v1")
|
||||
stream = client.chat_completion(
|
||||
max_tokens=100,
|
||||
max_tokens=20,
|
||||
seed=24,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
@ -371,7 +378,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_auto(
|
||||
|
||||
assert (
|
||||
content_generated
|
||||
== "There was a wise old octopus named Oracle. He lived in a cozy little cave beneath the waves with his best friend, a curious seahorse named Finley. One day, Finley met a playful dolphin named Daisy, and the three became inseparable. They spent their days exploring the ocean, playing hide-and-seek, and learning about the wonders of the sea from Oracle"
|
||||
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish,"
|
||||
)
|
||||
assert chunks == response_snapshot
|
||||
|
||||
@ -401,14 +408,18 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
||||
)
|
||||
|
||||
tool_calls_generated = ""
|
||||
name = ""
|
||||
chunks = []
|
||||
for chunk in stream:
|
||||
assert chunk.choices[0].delta.content is None
|
||||
if chunk.choices[0].delta.tool_calls[0].function.name:
|
||||
name += chunk.choices[0].delta.tool_calls[0].function.name
|
||||
tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
|
||||
assert name == "get_n_day_weather_forecast"
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}}<|eot_id|>'
|
||||
== '{ "location": "San Francisco, CA", "format": "fahrenheit", "num_days":3}'
|
||||
)
|
||||
assert chunks == response_snapshot
|
||||
|
||||
@ -479,12 +490,17 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
||||
)
|
||||
chunks = []
|
||||
tool_calls_generated = ""
|
||||
name = ""
|
||||
for chunk in stream:
|
||||
assert chunk.choices[0].delta.content is None
|
||||
if chunk.choices[0].delta.tool_calls[0].function.name:
|
||||
name += chunk.choices[0].delta.tool_calls[0].function.name
|
||||
tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||
chunks.append(chunk)
|
||||
|
||||
assert name == "get_n_day_weather_forecast"
|
||||
assert (
|
||||
tool_calls_generated
|
||||
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days": 3}}<|eot_id|>'
|
||||
== '{ "location": "San Francisco, CA", "format": "celsius", "num_days": 3}'
|
||||
)
|
||||
assert chunks == response_snapshot
|
||||
|
||||
|
@ -49,17 +49,11 @@ async def test_model_single_request(tgi_service):
|
||||
max_new_tokens=128,
|
||||
seed=42,
|
||||
)
|
||||
sample_expectations = {
|
||||
"gpt2": "Deep Learning",
|
||||
"llama": "Deep Learning",
|
||||
"mistral": "Deep learning",
|
||||
"qwen2": "Deep Learning",
|
||||
"granite": "Deep learning",
|
||||
}
|
||||
assert sample_expectations[service_name] in response
|
||||
# The response must be different
|
||||
assert not response.startswith(greedy_expectations[service_name])
|
||||
|
||||
# Sampling with stop sequence
|
||||
stop_sequence = sample_expectations[service_name][-5:]
|
||||
# Sampling with stop sequence (using one of the words returned from the previous test)
|
||||
stop_sequence = response.split(" ")[-5]
|
||||
response = await tgi_service.client.text_generation(
|
||||
"What is Deep Learning?",
|
||||
do_sample=True,
|
||||
|
@ -15,6 +15,7 @@ dependencies = [
|
||||
"numpy>=2.0",
|
||||
"openai>=1.65",
|
||||
"huggingface_hub>=0.29",
|
||||
"pillow>=11.1.0",
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
|
@ -1,8 +1,8 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml -o requirements.txt
|
||||
aiohappyeyeballs==2.4.6
|
||||
# uv pip compile pyproject.toml
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.11.12
|
||||
aiohttp==3.11.13
|
||||
# via text-generation
|
||||
aiosignal==1.3.2
|
||||
# via aiohttp
|
||||
@ -12,7 +12,7 @@ anyio==4.8.0
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
attrs==25.1.0
|
||||
attrs==25.3.0
|
||||
# via aiohttp
|
||||
certifi==2025.1.31
|
||||
# via
|
||||
@ -25,13 +25,13 @@ distro==1.9.0
|
||||
# via openai
|
||||
docker==7.1.0
|
||||
# via text-generation-integration-tests (pyproject.toml)
|
||||
filelock==3.17.0
|
||||
filelock==3.18.0
|
||||
# via huggingface-hub
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2025.2.0
|
||||
fsspec==2025.3.0
|
||||
# via huggingface-hub
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
@ -39,7 +39,7 @@ httpcore==1.0.7
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via openai
|
||||
huggingface-hub==0.29.0
|
||||
huggingface-hub==0.29.3
|
||||
# via
|
||||
# text-generation-integration-tests (pyproject.toml)
|
||||
# text-generation
|
||||
@ -51,7 +51,7 @@ idna==3.10
|
||||
# yarl
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
jiter==0.8.2
|
||||
jiter==0.9.0
|
||||
# via openai
|
||||
multidict==6.1.0
|
||||
# via
|
||||
@ -59,15 +59,17 @@ multidict==6.1.0
|
||||
# yarl
|
||||
numpy==2.2.3
|
||||
# via text-generation-integration-tests (pyproject.toml)
|
||||
openai==1.65.3
|
||||
openai==1.66.3
|
||||
# via text-generation-integration-tests (pyproject.toml)
|
||||
packaging==24.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# pytest
|
||||
pillow==11.1.0
|
||||
# via text-generation-integration-tests (pyproject.toml)
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
propcache==0.2.1
|
||||
propcache==0.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
@ -78,7 +80,7 @@ pydantic==2.10.6
|
||||
# text-generation
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pytest==8.3.4
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# text-generation-integration-tests (pyproject.toml)
|
||||
# pytest-asyncio
|
||||
@ -95,7 +97,7 @@ sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# openai
|
||||
syrupy==4.8.1
|
||||
syrupy==4.9.0
|
||||
# via text-generation-integration-tests (pyproject.toml)
|
||||
text-generation==0.7.0
|
||||
# via text-generation-integration-tests (pyproject.toml)
|
||||
|
@ -97,6 +97,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "idna" },
|
||||
{ name = "sniffio" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a3/73/199a98fc2dae33535d6b8e8e6ec01f8c1d76c9adb096c6b7d64823038cde/anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a", size = 181126 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-timeout"
|
||||
version = "5.0.1"
|
||||
@ -181,6 +196,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
@ -276,6 +300,43 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/94/758680531a00d06e471ef649e4ec2ed6bf185356a7f9fbfbb7368a40bd49/fsspec-2025.2.0-py3-none-any.whl", hash = "sha256:9de2ad9ce1f85e1931858535bc882543171d197001a0a5eb2ddc04f1781ab95b", size = 184484 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h11"
|
||||
version = "0.14.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "h11" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.28.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "certifi" },
|
||||
{ name = "httpcore" },
|
||||
{ name = "idna" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.29.0"
|
||||
@ -312,6 +373,50 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jiter"
|
||||
version = "0.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1e/c2/e4562507f52f0af7036da125bb699602ead37a2332af0788f8e0a3417f36/jiter-0.9.0.tar.gz", hash = "sha256:aadba0964deb424daa24492abc3d229c60c4a31bfee205aedbf1acc7639d7893", size = 162604 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/82/39f7c9e67b3b0121f02a0b90d433626caa95a565c3d2449fea6bcfa3f5f5/jiter-0.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:816ec9b60fdfd1fec87da1d7ed46c66c44ffec37ab2ef7de5b147b2fce3fd5ad", size = 314540 },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/07/7bf6022c5a152fca767cf5c086bb41f7c28f70cf33ad259d023b53c0b858/jiter-0.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9b1d3086f8a3ee0194ecf2008cf81286a5c3e540d977fa038ff23576c023c0ea", size = 321065 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/b2/de3f3446ecba7c48f317568e111cc112613da36c7b29a6de45a1df365556/jiter-0.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1339f839b91ae30b37c409bf16ccd3dc453e8b8c3ed4bd1d6a567193651a4a51", size = 341664 },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/cf/6485a4012af5d407689c91296105fcdb080a3538e0658d2abf679619c72f/jiter-0.9.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ffba79584b3b670fefae66ceb3a28822365d25b7bf811e030609a3d5b876f538", size = 364635 },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/f7/4a491c568f005553240b486f8e05c82547340572d5018ef79414b4449327/jiter-0.9.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cfc7d0a8e899089d11f065e289cb5b2daf3d82fbe028f49b20d7b809193958d", size = 406288 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/ca/f4263ecbce7f5e6bded8f52a9f1a66540b270c300b5c9f5353d163f9ac61/jiter-0.9.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e00a1a2bbfaaf237e13c3d1592356eab3e9015d7efd59359ac8b51eb56390a12", size = 397499 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/a2/522039e522a10bac2f2194f50e183a49a360d5f63ebf46f6d890ef8aa3f9/jiter-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1d9870561eb26b11448854dce0ff27a9a27cb616b632468cafc938de25e9e51", size = 352926 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/67/306a5c5abc82f2e32bd47333a1c9799499c1c3a415f8dde19dbf876f00cb/jiter-0.9.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9872aeff3f21e437651df378cb75aeb7043e5297261222b6441a620218b58708", size = 384506 },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/89/c12fe7b65a4fb74f6c0d7b5119576f1f16c79fc2953641f31b288fad8a04/jiter-0.9.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1fd19112d1049bdd47f17bfbb44a2c0001061312dcf0e72765bfa8abd4aa30e5", size = 520621 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/2b/d57900c5c06e6273fbaa76a19efa74dbc6e70c7427ab421bf0095dfe5d4a/jiter-0.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6ef5da104664e526836070e4a23b5f68dec1cc673b60bf1edb1bfbe8a55d0678", size = 512613 },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/05/d8b90bfb21e58097d5a4e0224f2940568366f68488a079ae77d4b2653500/jiter-0.9.0-cp310-cp310-win32.whl", hash = "sha256:cb12e6d65ebbefe5518de819f3eda53b73187b7089040b2d17f5b39001ff31c4", size = 206613 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/1d/5767f23f88e4f885090d74bbd2755518050a63040c0f59aa059947035711/jiter-0.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:c43ca669493626d8672be3b645dbb406ef25af3f4b6384cfd306da7eb2e70322", size = 208371 },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/44/e241a043f114299254e44d7e777ead311da400517f179665e59611ab0ee4/jiter-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6c4d99c71508912a7e556d631768dcdef43648a93660670986916b297f1c54af", size = 314654 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/1b/a7e5e42db9fa262baaa9489d8d14ca93f8663e7f164ed5e9acc9f467fc00/jiter-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8f60fb8ce7df529812bf6c625635a19d27f30806885139e367af93f6e734ef58", size = 320909 },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/bf/8ebdfce77bc04b81abf2ea316e9c03b4a866a7d739cf355eae4d6fd9f6fe/jiter-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51c4e1a4f8ea84d98b7b98912aa4290ac3d1eabfde8e3c34541fae30e9d1f08b", size = 341733 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/4e/754ebce77cff9ab34d1d0fa0fe98f5d42590fd33622509a3ba6ec37ff466/jiter-0.9.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f4c677c424dc76684fea3e7285a7a2a7493424bea89ac441045e6a1fb1d7b3b", size = 365097 },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/2c/6019587e6f5844c612ae18ca892f4cd7b3d8bbf49461ed29e384a0f13d98/jiter-0.9.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2221176dfec87f3470b21e6abca056e6b04ce9bff72315cb0b243ca9e835a4b5", size = 406603 },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/e9/c9e6546c817ab75a1a7dab6dcc698e62e375e1017113e8e983fccbd56115/jiter-0.9.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c7adb66f899ffa25e3c92bfcb593391ee1947dbdd6a9a970e0d7e713237d572", size = 396625 },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/bd/976b458add04271ebb5a255e992bd008546ea04bb4dcadc042a16279b4b4/jiter-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c98d27330fdfb77913c1097a7aab07f38ff2259048949f499c9901700789ac15", size = 351832 },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/51/fe59e307aaebec9265dbad44d9d4381d030947e47b0f23531579b9a7c2df/jiter-0.9.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eda3f8cc74df66892b1d06b5d41a71670c22d95a1ca2cbab73654745ce9d0419", size = 384590 },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/55/5dcd2693794d8e6f4889389ff66ef3be557a77f8aeeca8973a97a7c00557/jiter-0.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:dd5ab5ddc11418dce28343123644a100f487eaccf1de27a459ab36d6cca31043", size = 520690 },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/d5/9f51dc90985e9eb251fbbb747ab2b13b26601f16c595a7b8baba964043bd/jiter-0.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:42f8a68a69f047b310319ef8e2f52fdb2e7976fb3313ef27df495cf77bcad965", size = 512649 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/e5/4e385945179bcf128fa10ad8dca9053d717cbe09e258110e39045c881fe5/jiter-0.9.0-cp311-cp311-win32.whl", hash = "sha256:a25519efb78a42254d59326ee417d6f5161b06f5da827d94cf521fed961b1ff2", size = 206920 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/47/5e0b94c603d8e54dd1faab439b40b832c277d3b90743e7835879ab663757/jiter-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:923b54afdd697dfd00d368b7ccad008cccfeb1efb4e621f32860c75e9f25edbd", size = 210119 },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/d7/c55086103d6f29b694ec79156242304adf521577530d9031317ce5338c59/jiter-0.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7b46249cfd6c48da28f89eb0be3f52d6fdb40ab88e2c66804f546674e539ec11", size = 309203 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/01/f775dfee50beb420adfd6baf58d1c4d437de41c9b666ddf127c065e5a488/jiter-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:609cf3c78852f1189894383cf0b0b977665f54cb38788e3e6b941fa6d982c00e", size = 319678 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/b8/09b73a793714726893e5d46d5c534a63709261af3d24444ad07885ce87cb/jiter-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d726a3890a54561e55a9c5faea1f7655eda7f105bd165067575ace6e65f80bb2", size = 341816 },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/6f/b8f89ec5398b2b0d344257138182cc090302854ed63ed9c9051e9c673441/jiter-0.9.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2e89dc075c1fef8fa9be219e249f14040270dbc507df4215c324a1839522ea75", size = 364152 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/ca/978cc3183113b8e4484cc7e210a9ad3c6614396e7abd5407ea8aa1458eef/jiter-0.9.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04e8ffa3c353b1bc4134f96f167a2082494351e42888dfcf06e944f2729cbe1d", size = 406991 },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/3a/72861883e11a36d6aa314b4922125f6ae90bdccc225cd96d24cc78a66385/jiter-0.9.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:203f28a72a05ae0e129b3ed1f75f56bc419d5f91dfacd057519a8bd137b00c42", size = 395824 },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/67/22728a86ef53589c3720225778f7c5fdb617080e3deaed58b04789418212/jiter-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fca1a02ad60ec30bb230f65bc01f611c8608b02d269f998bc29cca8619a919dc", size = 351318 },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/b9/f39728e2e2007276806d7a6609cda7fac44ffa28ca0d02c49a4f397cc0d9/jiter-0.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:237e5cee4d5d2659aaf91bbf8ec45052cc217d9446070699441a91b386ae27dc", size = 384591 },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/8f/8a708bc7fd87b8a5d861f1c118a995eccbe6d672fe10c9753e67362d0dd0/jiter-0.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:528b6b71745e7326eed73c53d4aa57e2a522242320b6f7d65b9c5af83cf49b6e", size = 520746 },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/1e/65680c7488bd2365dbd2980adaf63c562d3d41d3faac192ebc7ef5b4ae25/jiter-0.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9f48e86b57bc711eb5acdfd12b6cb580a59cc9a993f6e7dcb6d8b50522dcd50d", size = 512754 },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/f3/fdc43547a9ee6e93c837685da704fb6da7dba311fc022e2766d5277dfde5/jiter-0.9.0-cp312-cp312-win32.whl", hash = "sha256:699edfde481e191d81f9cf6d2211debbfe4bd92f06410e7637dffb8dd5dfde06", size = 207075 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/9d/742b289016d155f49028fe1bfbeb935c9bf0ffeefdf77daf4a63a42bb72b/jiter-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:099500d07b43f61d8bd780466d429c45a7b25411b334c60ca875fa775f68ccb0", size = 207999 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multidict"
|
||||
version = "6.1.0"
|
||||
@ -411,6 +516,25 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/17/7f/d322a4125405920401450118dbdc52e0384026bd669939484670ce8b2ab9/numpy-2.2.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:783145835458e60fa97afac25d511d00a1eca94d4a8f3ace9fe2043003c678e4", size = 12839607 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.66.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "distro" },
|
||||
{ name = "httpx" },
|
||||
{ name = "jiter" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "sniffio" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a3/77/5172104ca1df35ed2ed8fb26dbc787f721c39498fc51d666c4db07756a0c/openai-1.66.3.tar.gz", hash = "sha256:8dde3aebe2d081258d4159c4cb27bdc13b5bb3f7ea2201d9bd940b9a89faf0c9", size = 397244 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/78/5a/e20182f7b6171642d759c548daa0ba20a1d3ac10d2bd0a13fd75704a9ac3/openai-1.66.3-py3-none-any.whl", hash = "sha256:a427c920f727711877ab17c11b95f1230b27767ba7a01e5b66102945141ceca9", size = 567400 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
@ -420,6 +544,54 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "11.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/af/c097e544e7bd278333db77933e535098c259609c4eb3b85381109602fb5b/pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20", size = 46742715 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/50/1c/2dcea34ac3d7bc96a1fd1bd0a6e06a57c67167fec2cff8d95d88229a8817/pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8", size = 3229983 },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/ca/6bec3df25e4c88432681de94a3531cc738bd85dea6c7aa6ab6f81ad8bd11/pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192", size = 3101831 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/2c/668e18e5521e46eb9667b09e501d8e07049eb5bfe39d56be0724a43117e6/pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2", size = 4314074 },
|
||||
{ url = "https://files.pythonhosted.org/packages/02/80/79f99b714f0fc25f6a8499ecfd1f810df12aec170ea1e32a4f75746051ce/pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26", size = 4394933 },
|
||||
{ url = "https://files.pythonhosted.org/packages/81/aa/8d4ad25dc11fd10a2001d5b8a80fdc0e564ac33b293bdfe04ed387e0fd95/pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07", size = 4353349 },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/7a/cd0c3eaf4a28cb2a74bdd19129f7726277a7f30c4f8424cd27a62987d864/pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482", size = 4476532 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/8b/a907fdd3ae8f01c7670dfb1499c53c28e217c338b47a813af8d815e7ce97/pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e", size = 4279789 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/9a/9f139d9e8cccd661c3efbf6898967a9a337eb2e9be2b454ba0a09533100d/pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269", size = 4413131 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/68/0d8d461f42a3f37432203c8e6df94da10ac8081b6d35af1c203bf3111088/pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49", size = 2291213 },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/81/d0dff759a74ba87715509af9f6cb21fa21d93b02b3316ed43bda83664db9/pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a", size = 2625725 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/1f/8d50c096a1d58ef0584ddc37e6f602828515219e9d2428e14ce50f5ecad1/pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65", size = 2375213 },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/d6/2000bfd8d5414fb70cbbe52c8332f2283ff30ed66a9cde42716c8ecbe22c/pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457", size = 3229968 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/45/3fe487010dd9ce0a06adf9b8ff4f273cc0a44536e234b0fad3532a42c15b/pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35", size = 3101806 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/72/776b3629c47d9d5f1c160113158a7a7ad177688d3a1159cd3b62ded5a33a/pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2", size = 4322283 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e4/c2/e25199e7e4e71d64eeb869f5b72c7ddec70e0a87926398785ab944d92375/pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070", size = 4402945 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/ed/51d6136c9d5911f78632b1b86c45241c712c5a80ed7fa7f9120a5dff1eba/pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6", size = 4361228 },
|
||||
{ url = "https://files.pythonhosted.org/packages/48/a4/fbfe9d5581d7b111b28f1d8c2762dee92e9821bb209af9fa83c940e507a0/pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1", size = 4484021 },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/db/0b3c1a5018117f3c1d4df671fb8e47d08937f27519e8614bbe86153b65a5/pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2", size = 4287449 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/58/bc128da7fea8c89fc85e09f773c4901e95b5936000e6f303222490c052f3/pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96", size = 4419972 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/bb/58f34379bde9fe197f51841c5bbe8830c28bbb6d3801f16a83b8f2ad37df/pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f", size = 2291201 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/c6/fce9255272bcf0c39e15abd2f8fd8429a954cf344469eaceb9d0d1366913/pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761", size = 2625686 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/52/8ba066d569d932365509054859f74f2a9abee273edcef5cd75e4bc3e831e/pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71", size = 2375194 },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/20/9ce6ed62c91c073fcaa23d216e68289e19d95fb8188b9fb7a63d36771db8/pillow-11.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a", size = 3226818 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/d8/f6004d98579a2596c098d1e30d10b248798cceff82d2b77aa914875bfea1/pillow-11.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b", size = 3101662 },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/d9/892e705f90051c7a2574d9f24579c9e100c828700d78a63239676f960b74/pillow-11.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3", size = 4329317 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/aa/7f29711f26680eab0bcd3ecdd6d23ed6bce180d82e3f6380fb7ae35fcf3b/pillow-11.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a", size = 4412999 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/c4/8f0fe3b9e0f7196f6d0bbb151f9fba323d72a41da068610c4c960b16632a/pillow-11.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1", size = 4368819 },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/0d/84200ed6a871ce386ddc82904bfadc0c6b28b0c0ec78176871a4679e40b3/pillow-11.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f", size = 4496081 },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/9c/9bcd66f714d7e25b64118e3952d52841a4babc6d97b6d28e2261c52045d4/pillow-11.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91", size = 4296513 },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/61/ada2a226e22da011b45f7104c95ebda1b63dcbb0c378ad0f7c2a710f8fd2/pillow-11.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c", size = 4431298 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/c4/fc6e86750523f367923522014b821c11ebc5ad402e659d8c9d09b3c9d70c/pillow-11.1.0-cp312-cp312-win32.whl", hash = "sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6", size = 2291630 },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/5c/2104299949b9d504baf3f4d35f73dbd14ef31bbd1ddc2c1b66a5b7dfda44/pillow-11.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf", size = 2626369 },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/f3/9b18362206b244167c958984b57c7f70a0289bfb59a530dd8af5f699b910/pillow-11.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5", size = 2375240 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/c5/389961578fb677b8b3244fcd934f720ed25a148b9a5cc81c91bdf59d8588/pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90", size = 3198345 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/fa/803c0e50ffee74d4b965229e816af55276eac1d5806712de86f9371858fd/pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb", size = 3072938 },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/67/2a3a5f8012b5d8c63fe53958ba906c1b1d0482ebed5618057ef4d22f8076/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442", size = 3400049 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/a0/514f0d317446c98c478d1872497eb92e7cde67003fed74f696441e647446/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83", size = 3422431 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/00/20f40a935514037b7d3f87adfc87d2c538430ea625b63b3af8c3f5578e72/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f", size = 3446208 },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/3c/7de681727963043e093c72e6c3348411b0185eab3263100d4490234ba2f6/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73", size = 3509746 },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.5.0"
|
||||
@ -656,6 +828,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syrupy"
|
||||
version = "4.8.1"
|
||||
@ -688,7 +869,10 @@ version = "2.0.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "docker" },
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "numpy" },
|
||||
{ name = "openai" },
|
||||
{ name = "pillow" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
@ -699,7 +883,10 @@ dependencies = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "docker", specifier = ">=7" },
|
||||
{ name = "huggingface-hub", specifier = ">=0.29" },
|
||||
{ name = "numpy", specifier = ">=2.0" },
|
||||
{ name = "openai", specifier = ">=1.65" },
|
||||
{ name = "pillow", specifier = ">=11.1.0" },
|
||||
{ name = "pydantic", specifier = ">2,<3" },
|
||||
{ name = "pytest", specifier = ">=8.3.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.23.1" },
|
||||
@ -741,7 +928,7 @@ name = "tqdm"
|
||||
version = "4.67.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
|
||||
wheels = [
|
||||
|
@ -97,11 +97,10 @@ fn get_config(
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
|
||||
let mut builder = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||
let mut builder = ApiBuilder::from_env();
|
||||
if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||
// env variable has precedence over on file token.
|
||||
ApiBuilder::new().with_token(Some(token))
|
||||
} else {
|
||||
ApiBuilder::new()
|
||||
builder = builder.with_token(Some(token))
|
||||
};
|
||||
if let Ok(origin) = env::var("HF_HUB_USER_AGENT_ORIGIN") {
|
||||
builder = builder.with_user_agent("origin", origin.as_str());
|
||||
@ -152,7 +151,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
"flashdecoding"
|
||||
};
|
||||
|
||||
match config.head_dim {
|
||||
match config.get_head_dim() {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||
@ -214,6 +213,7 @@ struct RawConfig {
|
||||
num_key_value_heads: Option<usize>,
|
||||
num_hidden_layers: Option<usize>,
|
||||
head_dim: Option<usize>,
|
||||
text_config: Option<TextConfig>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: Option<bool>,
|
||||
#[serde(rename = "num_experts_per_tok")]
|
||||
@ -233,6 +233,11 @@ struct QuantizationConfig {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct VisionConfig {}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TextConfig {
|
||||
head_dim: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Config {
|
||||
max_position_embeddings: Option<usize>,
|
||||
@ -244,6 +249,7 @@ struct Config {
|
||||
intermediate_size: Option<usize>,
|
||||
hidden_size: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
text_config: Option<TextConfig>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: bool,
|
||||
num_experts_per_token: usize,
|
||||
@ -253,6 +259,14 @@ struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn get_head_dim(&self) -> Option<usize> {
|
||||
self.head_dim.or_else(|| {
|
||||
self.text_config
|
||||
.as_ref()
|
||||
.and_then(|text_config| text_config.head_dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn flop(&self) -> Option<u64> {
|
||||
if self.vision_config.is_some() {
|
||||
// VLM are much harder to predict and VRAM requirements
|
||||
@ -261,7 +275,7 @@ impl Config {
|
||||
}
|
||||
let num_heads = self.num_heads? as u64;
|
||||
let num_kv_heads = self.num_kv_heads? as u64;
|
||||
let head_dim = self.head_dim? as u64;
|
||||
let head_dim = self.get_head_dim()? as u64;
|
||||
let hidden_size = self.hidden_size? as u64;
|
||||
let intermediate_size = (self.intermediate_size?
|
||||
* (self.num_experts_per_token + self.num_shared_experts))
|
||||
@ -289,7 +303,7 @@ impl Config {
|
||||
}
|
||||
// 2 for key and values
|
||||
// 2 for f16 dtype?
|
||||
Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
|
||||
Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
|
||||
}
|
||||
|
||||
fn mlp_vram_per_tok(&self) -> Option<usize> {
|
||||
@ -310,8 +324,8 @@ impl Config {
|
||||
}
|
||||
|
||||
fn model_vram(&self) -> Option<usize> {
|
||||
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
|
||||
let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
|
||||
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
|
||||
let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
|
||||
// gate + up + down = 3
|
||||
let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
|
||||
let layer_vram = mlp_vram + attn_vram + o_vram;
|
||||
@ -349,6 +363,7 @@ impl From<RawConfig> for Config {
|
||||
let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
|
||||
let intermediate_size = other.intermediate_size;
|
||||
let model_type = other.model_type;
|
||||
let text_config = other.text_config;
|
||||
let vision_config = other.vision_config;
|
||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
|
||||
@ -360,6 +375,7 @@ impl From<RawConfig> for Config {
|
||||
quantize,
|
||||
head_dim,
|
||||
model_type,
|
||||
text_config,
|
||||
vision_config,
|
||||
is_encoder_decoder,
|
||||
hidden_size,
|
||||
@ -2067,6 +2083,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
let default_optimal = match config {
|
||||
Some(ref config) => match config.model_type.as_deref() {
|
||||
Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
|
||||
Some("gemma3") => 8000,
|
||||
_ => 4096,
|
||||
},
|
||||
None => 4096,
|
||||
|
@ -1,4 +1,5 @@
|
||||
{
|
||||
stdenv,
|
||||
dockerTools,
|
||||
cacert,
|
||||
text-generation-inference,
|
||||
@ -11,13 +12,25 @@ in
|
||||
build {
|
||||
name = "tgi-docker";
|
||||
tag = "latest";
|
||||
compressor = "zstd";
|
||||
config = {
|
||||
EntryPoint = [ "${text-generation-inference}/bin/text-generation-inference" ];
|
||||
Env = [
|
||||
"HF_HOME=/data"
|
||||
"PORT=80"
|
||||
# The CUDA container toolkit will mount the driver shim into the
|
||||
# container. We just have to ensure that the dynamic loader finds
|
||||
# the libraries.
|
||||
"LD_LIBRARY_PATH=/usr/lib64"
|
||||
];
|
||||
|
||||
};
|
||||
contents = [ cacert ];
|
||||
extraCommands = ''
|
||||
mkdir -p tmp
|
||||
chmod -R 1777 tmp
|
||||
'';
|
||||
contents = [
|
||||
cacert
|
||||
stdenv.cc
|
||||
];
|
||||
}
|
||||
|
@ -16,8 +16,8 @@
|
||||
grpcio-reflection,
|
||||
grpcio-status,
|
||||
grpcio-tools,
|
||||
hf-kernels,
|
||||
hf-transfer,
|
||||
kernels,
|
||||
loguru,
|
||||
mamba-ssm,
|
||||
moe,
|
||||
@ -91,8 +91,8 @@ buildPythonPackage {
|
||||
grpcio-reflection
|
||||
grpcio-status
|
||||
grpcio-tools
|
||||
hf-kernels
|
||||
hf-transfer
|
||||
kernels
|
||||
loguru
|
||||
mamba-ssm
|
||||
moe
|
||||
|
700
router/src/chat.rs
Normal file
700
router/src/chat.rs
Normal file
@ -0,0 +1,700 @@
|
||||
use crate::{
|
||||
infer::InferError, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
|
||||
ChatCompletionLogprobs, CompletionType, DeltaToolCall, Function, FunctionDefinition,
|
||||
StreamOptions, StreamResponse, TextMessage, ToolCallDelta, Usage,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToolCall {
|
||||
_name: String,
|
||||
#[serde(flatten, default)]
|
||||
/// Using Map to preserve order
|
||||
arguments: serde_json::Map<String, Value>,
|
||||
}
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Call {
|
||||
function: ToolCall,
|
||||
}
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ChatEvent {
|
||||
NoTool,
|
||||
Events(Vec<CompletionType>),
|
||||
}
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ChatChoice {
|
||||
NoTool,
|
||||
ToolCalls(Vec<crate::ToolCall>),
|
||||
}
|
||||
|
||||
pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
|
||||
let call: Call = serde_json::from_str(generated_text).map_err(|e| {
|
||||
InferError::ToolError(format!(
|
||||
"Failed to parse generated text: {} {:?}",
|
||||
e, generated_text
|
||||
))
|
||||
})?;
|
||||
let name = call.function._name;
|
||||
|
||||
match &name[..] {
|
||||
"no_tool" => {
|
||||
// parse the content message
|
||||
Ok(ChatChoice::NoTool)
|
||||
}
|
||||
name => {
|
||||
let tool_calls = vec![crate::ToolCall {
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: name.to_string(),
|
||||
arguments: serde_json::to_value(call.function.arguments).map_err(|err| {
|
||||
InferError::ToolError(format!(
|
||||
"Could not convert arguments to JSON map {err}"
|
||||
))
|
||||
})?,
|
||||
},
|
||||
}];
|
||||
Ok(ChatChoice::ToolCalls(tool_calls))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a StreamResponse into an Event to be sent over SSE
|
||||
fn create_event_from_stream_token(
|
||||
stream_token: &StreamResponse,
|
||||
logprobs: bool,
|
||||
inner_using_tools: bool,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
function_name: Option<String>,
|
||||
id: String,
|
||||
) -> CompletionType {
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
// Cast into a vec
|
||||
(None, content)
|
||||
} else {
|
||||
(content, None)
|
||||
};
|
||||
let finish_reason = stream_token
|
||||
.details
|
||||
.as_ref()
|
||||
.map(|details| details.finish_reason.format(true));
|
||||
let delta = match (content, tool_calls) {
|
||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: delta,
|
||||
..Default::default()
|
||||
}),
|
||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: function_name,
|
||||
arguments: tool_calls,
|
||||
},
|
||||
}],
|
||||
}),
|
||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "".to_string(),
|
||||
..Default::default()
|
||||
}),
|
||||
};
|
||||
let choices = vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
}];
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id,
|
||||
system_fingerprint,
|
||||
current_time,
|
||||
choices,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum StreamState {
|
||||
/// Before the tools was parsed
|
||||
Buffering,
|
||||
/// We detected a tool call here
|
||||
Tool,
|
||||
/// This is without tool calling
|
||||
Content,
|
||||
}
|
||||
|
||||
pub struct ChatState {
|
||||
state: StreamState,
|
||||
text: String,
|
||||
options: StreamOptions,
|
||||
model_id: String,
|
||||
fingerprint: String,
|
||||
logprobs: bool,
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl ChatState {
|
||||
pub fn new(
|
||||
using_tools: bool,
|
||||
options: StreamOptions,
|
||||
fingerprint: String,
|
||||
model_id: String,
|
||||
logprobs: bool,
|
||||
id: String,
|
||||
) -> Self {
|
||||
let state = if using_tools {
|
||||
StreamState::Buffering
|
||||
} else {
|
||||
StreamState::Content
|
||||
};
|
||||
let text = String::new();
|
||||
Self {
|
||||
state,
|
||||
text,
|
||||
options,
|
||||
fingerprint,
|
||||
model_id,
|
||||
logprobs,
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, mut stream_token: StreamResponse) -> ChatEvent {
|
||||
let mut events = vec![];
|
||||
let token_text = &stream_token.token.text;
|
||||
match self.state {
|
||||
StreamState::Buffering => {
|
||||
self.text.push_str(token_text);
|
||||
tracing::info!("Current text {:?}", self.text);
|
||||
let partial = &self.text;
|
||||
let partial =
|
||||
partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');
|
||||
if let Ok(call) = serde_json::from_str::<Call>(&format!("{}}}}}", partial)) {
|
||||
// This can be no_tool before the content has been emitted
|
||||
if call.function._name != "no_tool" {
|
||||
stream_token.token.text = "{".to_string();
|
||||
let chat_complete = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
self.logprobs,
|
||||
true,
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
Some(call.function._name),
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
self.state = StreamState::Tool;
|
||||
} else {
|
||||
return ChatEvent::NoTool;
|
||||
}
|
||||
}
|
||||
}
|
||||
StreamState::Tool => {
|
||||
self.text.push_str(token_text);
|
||||
if serde_json::from_str::<Call>(&self.text).is_ok() {
|
||||
self.state = StreamState::Buffering;
|
||||
let mut text = stream_token.token.text.trim_end();
|
||||
// Effectively trimming only the last closing brace
|
||||
if text.ends_with('}') {
|
||||
text = &text[..text.len() - 1];
|
||||
}
|
||||
stream_token.token.text = text.to_string();
|
||||
let chat_complete = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
self.logprobs,
|
||||
true,
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
events.push(chat_complete);
|
||||
} else {
|
||||
let chat_complete = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
self.logprobs,
|
||||
true,
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
events.push(chat_complete);
|
||||
}
|
||||
}
|
||||
StreamState::Content => {
|
||||
let chat_complete = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
self.logprobs,
|
||||
false,
|
||||
self.fingerprint.clone(),
|
||||
self.model_id.clone(),
|
||||
None,
|
||||
self.id.clone(),
|
||||
);
|
||||
|
||||
events.push(chat_complete);
|
||||
}
|
||||
}
|
||||
|
||||
if self.options.include_usage {
|
||||
if let Some(details) = stream_token.details {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
let usage = Usage {
|
||||
completion_tokens,
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
};
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: self.model_id.clone(),
|
||||
system_fingerprint: self.fingerprint.clone(),
|
||||
choices: vec![],
|
||||
usage: Some(Usage {
|
||||
prompt_tokens: usage.prompt_tokens,
|
||||
completion_tokens: usage.completion_tokens,
|
||||
total_tokens: usage.total_tokens,
|
||||
}),
|
||||
});
|
||||
|
||||
events.push(chat_complete);
|
||||
}
|
||||
}
|
||||
ChatEvent::Events(events)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
ChatCompletionChoice, ChatCompletionDelta, FinishReason, StreamDetails, TextMessage, Token,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {
|
||||
match event {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||
assert_eq!(choices.len(), 1);
|
||||
if let ChatCompletionChoice {
|
||||
delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. }),
|
||||
..
|
||||
} = &choices[0]
|
||||
{
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
let DeltaToolCall {
|
||||
index,
|
||||
id,
|
||||
r#type,
|
||||
function,
|
||||
} = &tool_calls[0];
|
||||
assert_eq!(*index, 0);
|
||||
assert_eq!(id, "0");
|
||||
assert_eq!(r#type, "function");
|
||||
(function.name.as_ref(), &function.arguments)
|
||||
} else {
|
||||
panic!("Expected plain message");
|
||||
}
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream() {
|
||||
let mut chat_state = ChatState::new(
|
||||
false,
|
||||
StreamOptions {
|
||||
include_usage: false,
|
||||
},
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let events = chat_state.push(StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: "Hi".to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
});
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 1);
|
||||
match &events[0] {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||
assert_eq!(
|
||||
choices,
|
||||
&[ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta: ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hi".to_string(),
|
||||
tool_call_id: None,
|
||||
}),
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
}]
|
||||
);
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
}
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream_usage() {
|
||||
let mut chat_state = ChatState::new(
|
||||
false,
|
||||
StreamOptions {
|
||||
include_usage: true,
|
||||
},
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let events = chat_state.push(StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: "Hi".to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: Some(StreamDetails {
|
||||
input_length: 2,
|
||||
generated_tokens: 10,
|
||||
seed: None,
|
||||
finish_reason: FinishReason::Length,
|
||||
}),
|
||||
});
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 2);
|
||||
match &events[0] {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {
|
||||
assert_eq!(
|
||||
choices,
|
||||
&[ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta: ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hi".to_string(),
|
||||
tool_call_id: None,
|
||||
}),
|
||||
logprobs: None,
|
||||
// HAS A FINISH REASON
|
||||
finish_reason: Some("length".to_string()),
|
||||
}]
|
||||
);
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
}
|
||||
match &events[1] {
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => {
|
||||
assert_eq!(
|
||||
*usage,
|
||||
Some(Usage {
|
||||
prompt_tokens: 2,
|
||||
completion_tokens: 10,
|
||||
total_tokens: 12,
|
||||
})
|
||||
);
|
||||
}
|
||||
_ => panic!("Unexpected chunk"),
|
||||
}
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream_tool_no_tool_simple() {
|
||||
let mut chat_state = ChatState::new(
|
||||
true,
|
||||
StreamOptions {
|
||||
include_usage: true,
|
||||
},
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
"{\"".to_string(),
|
||||
"function".to_string(),
|
||||
"\":".to_string(),
|
||||
" {\"".to_string(),
|
||||
"_".to_string(),
|
||||
"name".to_string(),
|
||||
"\":".to_string(),
|
||||
" \"".to_string(),
|
||||
"no".to_string(),
|
||||
"_tool".to_string(),
|
||||
"\",".to_string(),
|
||||
" \"".to_string(),
|
||||
"content".to_string(),
|
||||
"\":".to_string(),
|
||||
" \"".to_string(), // Token 14
|
||||
"I".to_string(), // Event 1
|
||||
" am".to_string(), // Event 2
|
||||
" a".to_string(), // Event 3
|
||||
" helpful".to_string(), // Event 4
|
||||
" assistant".to_string(), // Event 5
|
||||
"!\"".to_string(), // Event 6 (with trailing quore removed)
|
||||
"}".to_string(),
|
||||
"}".to_string(),
|
||||
];
|
||||
let tokens: Vec<_> = tokens
|
||||
.into_iter()
|
||||
.map(|text| StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: text.to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initial ignored output
|
||||
for token in &tokens[..10] {
|
||||
let events = chat_state.push(token.clone());
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 0, "{events:?}");
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
// No tool output
|
||||
let events = chat_state.push(tokens[10].clone());
|
||||
if let ChatEvent::NoTool = events {
|
||||
assert!(true);
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream_tool_no_tool_empty() {
|
||||
let mut chat_state = ChatState::new(
|
||||
true,
|
||||
StreamOptions {
|
||||
include_usage: true,
|
||||
},
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
"{\"".to_string(),
|
||||
"function".to_string(),
|
||||
"\":".to_string(),
|
||||
" {\"".to_string(),
|
||||
"_".to_string(),
|
||||
"name".to_string(),
|
||||
"\":".to_string(),
|
||||
" \"".to_string(),
|
||||
"no".to_string(),
|
||||
"_tool".to_string(),
|
||||
"\",".to_string(),
|
||||
" \"".to_string(),
|
||||
"content".to_string(),
|
||||
"\":\"".to_string(),
|
||||
"\"}".to_string(), // Token 13
|
||||
"}".to_string(), // Event 1
|
||||
];
|
||||
let tokens: Vec<_> = tokens
|
||||
.into_iter()
|
||||
.map(|text| StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: text.to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initial ignored output
|
||||
for token in &tokens[..10] {
|
||||
let events = chat_state.push(token.clone());
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 0, "{events:?}");
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
// No tool output
|
||||
let events = chat_state.push(tokens[10].clone());
|
||||
if let ChatEvent::NoTool = events {
|
||||
assert!(true);
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_stream_tool_get_weather() {
|
||||
let mut chat_state = ChatState::new(
|
||||
true,
|
||||
StreamOptions {
|
||||
include_usage: true,
|
||||
},
|
||||
"fingerprint".to_string(),
|
||||
"model_id".to_string(),
|
||||
false,
|
||||
"0".to_string(),
|
||||
);
|
||||
|
||||
let tokens = vec![
|
||||
"{\"".to_string(),
|
||||
"function".to_string(),
|
||||
"\":".to_string(),
|
||||
" {\"".to_string(),
|
||||
"_".to_string(),
|
||||
"name".to_string(),
|
||||
"\":".to_string(),
|
||||
" \"".to_string(),
|
||||
"get".to_string(),
|
||||
"_current".to_string(),
|
||||
"_weather".to_string(),
|
||||
"\",".to_string(),
|
||||
// Event 1 is the function name
|
||||
// Event 2 is the start of the arguments "{"
|
||||
" \"".to_string(), // Event 3
|
||||
"location".to_string(), // Event 4
|
||||
"\":".to_string(), // Event 5
|
||||
" \"".to_string(), // Event 6
|
||||
"San".to_string(), // Event 7
|
||||
" Francisco".to_string(), // Event 8
|
||||
",".to_string(), // Event 9
|
||||
" CA".to_string(), // Event 10
|
||||
"\",".to_string(), // Event 11
|
||||
" \"".to_string(), // Event 12
|
||||
"format".to_string(), // Event 13
|
||||
"\":".to_string(), // Event 14
|
||||
" \"".to_string(), // Event 15
|
||||
"c".to_string(), // Event 16
|
||||
"elsius".to_string(), // Event 17
|
||||
"\"}}".to_string(), // Event 18 retained (trailing brace removed)
|
||||
];
|
||||
let tokens: Vec<_> = tokens
|
||||
.into_iter()
|
||||
.map(|text| StreamResponse {
|
||||
generated_text: None,
|
||||
token: Token {
|
||||
id: 42,
|
||||
text: text.to_string(),
|
||||
logprob: 0.0,
|
||||
special: false,
|
||||
},
|
||||
top_tokens: vec![],
|
||||
index: 0,
|
||||
details: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initial ignored output
|
||||
for token in &tokens[..11] {
|
||||
let events = chat_state.push(token.clone());
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 0, "{events:?}");
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
// No tool output
|
||||
let mut output = String::new();
|
||||
let mut output_name = String::new();
|
||||
for token in &tokens[11..11 + 17] {
|
||||
let events = chat_state.push(token.clone());
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 1);
|
||||
let (name, arguments) = get_tool_call_content(&events[0]);
|
||||
if let Some(name) = name {
|
||||
assert_eq!(name, "get_current_weather");
|
||||
output_name.push_str(&name);
|
||||
}
|
||||
output.push_str(arguments);
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(output_name, "get_current_weather");
|
||||
assert_eq!(
|
||||
output,
|
||||
"{ \"location\": \"San Francisco, CA\", \"format\": \"celsius\"}"
|
||||
);
|
||||
|
||||
// No tool finish
|
||||
for token in &tokens[11 + 17..] {
|
||||
let events = chat_state.push(token.clone());
|
||||
if let ChatEvent::Events(events) = events {
|
||||
assert_eq!(events.len(), 0, "{events:?}");
|
||||
} else {
|
||||
panic!("Expected chat events");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -216,6 +216,19 @@ impl Qwen2_5Vl {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Gemma3VisionConfig {
|
||||
pub(crate) image_size: usize,
|
||||
pub(crate) patch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Gemma3 {
|
||||
vision_config: Gemma3VisionConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@ -249,6 +262,8 @@ pub enum Config {
|
||||
Paligemma(Paligemma),
|
||||
Gemma,
|
||||
Gemma2,
|
||||
Gemma3(Gemma3),
|
||||
Gemma3Text,
|
||||
Cohere,
|
||||
Drbx,
|
||||
Falcon,
|
||||
|
@ -16,7 +16,7 @@ pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Erro
|
||||
Ok(Local::now().format(&format_str).to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
@ -33,7 +33,16 @@ impl ChatTemplate {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
|
||||
// TODO: replace with better solution
|
||||
// hack to adjust gemma3 template for debug
|
||||
// replace 'messages[0]['content'][0]['text']' with 'messages[0]['content']'
|
||||
let mutated_template = template.replace(
|
||||
"messages[0]['content'][0]['text']",
|
||||
"messages[0]['content']",
|
||||
);
|
||||
|
||||
let template_str = mutated_template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
tracing::debug!("Loading template: {}", template_str);
|
||||
@ -123,8 +132,8 @@ mod tests {
|
||||
use crate::infer::chat_template::{raise_exception, strftime_now};
|
||||
use crate::infer::ChatTemplate;
|
||||
use crate::{
|
||||
ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
|
||||
TokenizerConfigToken, Tool,
|
||||
ChatTemplateInputs, Message, MessageBody, MessageChunk, MessageContent, TextMessage,
|
||||
TokenizerConfigToken, Tool, Url,
|
||||
};
|
||||
use chrono::Local;
|
||||
use minijinja::Environment;
|
||||
@ -1230,4 +1239,98 @@ TOOL CALL ID: 0
|
||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_special_system_prompt() {
|
||||
// chat template from gemma3
|
||||
let ct = ChatTemplate::new(
|
||||
r#"{{ bos_token }}
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
|
||||
|
||||
' -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif -%}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{%- set role = "model" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{{ '<start_of_turn>' + role + '
|
||||
' + (first_user_prefix if loop.first else "") }}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ '<start_of_image>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type") }}
|
||||
{%- endif -%}
|
||||
{{ '<end_of_turn>
|
||||
' }}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<start_of_turn>model
|
||||
'}}
|
||||
{%- endif -%}
|
||||
"#
|
||||
.to_string(),
|
||||
Some(TokenizerConfigToken::String("<bos>".to_string())),
|
||||
Some(TokenizerConfigToken::String("</eos>".to_string())),
|
||||
);
|
||||
let msgs: Vec<Message> = vec![
|
||||
Message {
|
||||
name: None,
|
||||
role: "system".to_string(),
|
||||
body: MessageBody::Content {
|
||||
content: MessageContent::MultipleChunks(vec![MessageChunk::Text {
|
||||
text: "You are a helpful assistant.".to_string(),
|
||||
}]),
|
||||
},
|
||||
},
|
||||
Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
body: MessageBody::Content {
|
||||
content: MessageContent::MultipleChunks(vec![
|
||||
MessageChunk::Text {
|
||||
text: "I'm already using this supplement ".to_string(),
|
||||
},
|
||||
MessageChunk::ImageUrl {
|
||||
image_url: Url {
|
||||
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG".to_string()
|
||||
},
|
||||
},
|
||||
MessageChunk::Text {
|
||||
text: "and I want to use this one too ".to_string()
|
||||
},
|
||||
MessageChunk::ImageUrl {
|
||||
image_url: Url {
|
||||
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg".to_string()
|
||||
},
|
||||
},
|
||||
MessageChunk::Text {
|
||||
text: " what are cautions?".to_string()
|
||||
},
|
||||
]),
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
let result = ct.apply(msgs, None);
|
||||
let expected = "<bos><start_of_turn>user\nYou are a helpful assistant.\n\nI'm already using this supplement and I want to use this one too  what are cautions?<end_of_turn>\n<start_of_turn>model\n".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ pub struct Infer {
|
||||
/// Request backend
|
||||
backend: Arc<dyn Backend + Send + Sync>,
|
||||
/// Chat template
|
||||
chat_template: Option<ChatTemplate>,
|
||||
pub(crate) chat_template: Option<ChatTemplate>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Backend health
|
||||
|
@ -40,13 +40,13 @@ impl ToolGrammar {
|
||||
),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The response content",
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
// "properties": {
|
||||
// "content": {
|
||||
// "type": "string",
|
||||
// "description": "The response content",
|
||||
// }
|
||||
// },
|
||||
// "required": ["content"]
|
||||
}),
|
||||
},
|
||||
}))
|
||||
|
@ -8,6 +8,7 @@ pub mod validation;
|
||||
mod kserve;
|
||||
pub mod logging;
|
||||
|
||||
mod chat;
|
||||
mod sagemaker;
|
||||
pub mod usage_stats;
|
||||
mod vertex;
|
||||
@ -20,6 +21,7 @@ use serde::{Deserialize, Serialize};
|
||||
use tokenizers::Encoding;
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
use uuid::Uuid;
|
||||
use validation::Validation;
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
@ -150,6 +152,11 @@ impl HubTokenizerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ChatTemplateStandalone {
|
||||
pub chat_template: ChatTemplateVersions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum TokenizerConfigToken {
|
||||
@ -171,6 +178,7 @@ impl TokenizerConfigToken {
|
||||
pub enum HubPreprocessorConfig {
|
||||
Idefics2Processor(Idefics2Preprocessor),
|
||||
Idefics3Processor(Idefics2Preprocessor),
|
||||
Gemma3Processor(Gemma3Processor),
|
||||
}
|
||||
|
||||
impl HubPreprocessorConfig {
|
||||
@ -186,6 +194,12 @@ pub struct Idefics2Preprocessor {
|
||||
do_image_splitting: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Gemma3Processor {
|
||||
#[serde(default)]
|
||||
do_image_splitting: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubProcessorConfig {
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
@ -541,6 +555,7 @@ pub(crate) struct Chunk {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) struct ChatCompletion {
|
||||
pub id: String,
|
||||
#[schema(example = "1706270835")]
|
||||
@ -553,6 +568,7 @@ pub(crate) struct ChatCompletion {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) struct ChatCompletionComplete {
|
||||
pub index: u32,
|
||||
pub message: OutputMessage,
|
||||
@ -561,6 +577,7 @@ pub(crate) struct ChatCompletionComplete {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct ChatCompletionLogprobs {
|
||||
content: Vec<ChatCompletionLogprob>,
|
||||
}
|
||||
@ -619,6 +636,7 @@ impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct ChatCompletionLogprob {
|
||||
token: String,
|
||||
logprob: f32,
|
||||
@ -626,12 +644,14 @@ pub(crate) struct ChatCompletionLogprob {
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct ChatCompletionTopLogprob {
|
||||
token: String,
|
||||
logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
@ -640,6 +660,7 @@ pub(crate) struct Usage {
|
||||
|
||||
#[derive(Clone, Serialize, ToSchema)]
|
||||
#[serde(tag = "object")]
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
enum CompletionType {
|
||||
#[serde(rename = "chat.completion.chunk")]
|
||||
ChatCompletionChunk(ChatCompletionChunk),
|
||||
@ -707,6 +728,7 @@ impl ChatCompletion {
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
#[schema(example = "1706270978")]
|
||||
@ -719,6 +741,7 @@ pub(crate) struct ChatCompletionChunk {
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, ToSchema)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
pub(crate) struct ChatCompletionChoice {
|
||||
pub index: u32,
|
||||
pub delta: ChatCompletionDelta,
|
||||
@ -735,6 +758,7 @@ pub struct ToolCallDelta {
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
enum ChatCompletionDelta {
|
||||
Chat(TextMessage),
|
||||
Tool(ToolCallDelta),
|
||||
@ -759,48 +783,17 @@ impl ChatCompletionChunk {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
delta: Option<String>,
|
||||
tool_calls: Option<Vec<String>>,
|
||||
created: u64,
|
||||
logprobs: Option<ChatCompletionLogprobs>,
|
||||
finish_reason: Option<String>,
|
||||
choices: Vec<ChatCompletionChoice>,
|
||||
usage: Option<Usage>,
|
||||
) -> Self {
|
||||
let delta = match (delta, tool_calls) {
|
||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: delta,
|
||||
..Default::default()
|
||||
}),
|
||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
arguments: tool_calls[0].to_string(),
|
||||
},
|
||||
}],
|
||||
}),
|
||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "".to_string(),
|
||||
..Default::default()
|
||||
}),
|
||||
};
|
||||
Self {
|
||||
id: String::new(),
|
||||
created,
|
||||
model,
|
||||
system_fingerprint,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
}],
|
||||
usage: None,
|
||||
choices,
|
||||
usage,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -915,7 +908,7 @@ pub(crate) struct ChatRequest {
|
||||
/// Options for streaming response. Only set this when you set stream: true.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
pub stream_options: StreamOptions,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
@ -1015,13 +1008,37 @@ impl ChatRequest {
|
||||
using_tools,
|
||||
))
|
||||
}
|
||||
|
||||
fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let mut id: usize = 0;
|
||||
for message in &self.messages {
|
||||
if let MessageBody::Tool { tool_calls } = &message.body {
|
||||
for tool_call in tool_calls {
|
||||
let new_id: usize = tool_call.id.parse()?;
|
||||
id = std::cmp::max(id, new_id + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(id.to_string())
|
||||
}
|
||||
|
||||
/// Try to have linearly increasing id
|
||||
/// or resort to using Uuid if the initial
|
||||
/// scheme is not understood
|
||||
fn next_tool_call_id(&self) -> String {
|
||||
self.next_int_id().unwrap_or_else(|_| {
|
||||
let uid = Uuid::new_v4().to_string();
|
||||
uid.to_string()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
struct StreamOptions {
|
||||
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||
#[schema(example = "true")]
|
||||
#[serde(default)]
|
||||
include_usage: bool,
|
||||
}
|
||||
|
||||
@ -1445,7 +1462,7 @@ pub(crate) struct ChatTokenizeResponse {
|
||||
#[serde(transparent)]
|
||||
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
#[derive(Serialize, ToSchema, Clone)]
|
||||
pub(crate) struct StreamDetails {
|
||||
#[schema(example = "length")]
|
||||
pub finish_reason: FinishReason,
|
||||
@ -1457,7 +1474,7 @@ pub(crate) struct StreamDetails {
|
||||
pub input_length: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
#[derive(Serialize, ToSchema, Clone)]
|
||||
pub(crate) struct StreamResponse {
|
||||
pub index: u32,
|
||||
pub token: Token,
|
||||
@ -1700,9 +1717,25 @@ mod tests {
|
||||
|
||||
assert!(matches!(
|
||||
request.stream_options,
|
||||
Some(StreamOptions {
|
||||
StreamOptions {
|
||||
include_usage: true
|
||||
})
|
||||
}
|
||||
));
|
||||
|
||||
let json = json!({
|
||||
"model": "",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}]
|
||||
});
|
||||
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
|
||||
|
||||
assert!(matches!(
|
||||
request.stream_options,
|
||||
StreamOptions {
|
||||
include_usage: false
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::chat::{ChatChoice, ChatEvent, ChatState};
|
||||
/// HTTP Server logic
|
||||
use crate::config::Config;
|
||||
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||
@ -47,8 +48,6 @@ use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
@ -1114,62 +1113,6 @@ pub(crate) async fn completions(
|
||||
}
|
||||
}
|
||||
|
||||
enum StreamState {
|
||||
Buffering,
|
||||
BufferTrailing,
|
||||
Content { skip_close_quote: bool },
|
||||
}
|
||||
|
||||
/// Convert a StreamResponse into an Event to be sent over SSE
|
||||
fn create_event_from_stream_token(
|
||||
stream_token: &StreamResponse,
|
||||
logprobs: bool,
|
||||
inner_using_tools: bool,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
) -> Event {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let logprobs = logprobs.then(|| {
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(content, None)
|
||||
};
|
||||
let finish_reason = stream_token
|
||||
.details
|
||||
.as_ref()
|
||||
.map(|details| details.finish_reason.format(true));
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
finish_reason,
|
||||
));
|
||||
|
||||
event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate tokens
|
||||
#[utoipa::path(
|
||||
post,
|
||||
@ -1208,7 +1151,7 @@ pub(crate) async fn chat_completions(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Extension(info): Extension<Info>,
|
||||
Json(chat): Json<ChatRequest>,
|
||||
Json(mut chat): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
@ -1219,8 +1162,11 @@ pub(crate) async fn chat_completions(
|
||||
logprobs,
|
||||
..
|
||||
} = chat.clone();
|
||||
|
||||
tracing::debug!("Got chat_template {:?}", infer.chat_template);
|
||||
let id = chat.next_tool_call_id();
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.try_into_generate(&infer)?;
|
||||
chat.clone().try_into_generate(&infer)?;
|
||||
span.record("parameters", format!("{:?}", generate_request.parameters));
|
||||
let logprobs = logprobs.unwrap_or_default();
|
||||
|
||||
@ -1232,167 +1178,41 @@ pub(crate) async fn chat_completions(
|
||||
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||
// switch on stream
|
||||
if stream {
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||
|
||||
// regex to match any function name
|
||||
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
|
||||
Ok(regex) => regex,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to compile regex: {}", e),
|
||||
error_type: "regex".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
};
|
||||
let (headers, response_stream) = generate_stream_internal(
|
||||
infer.clone(),
|
||||
compute_type.clone(),
|
||||
Json(generate_request),
|
||||
span.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
let mut buffer = Vec::new();
|
||||
let mut json_buffer = String::new();
|
||||
let mut state = if using_tools {
|
||||
StreamState::Buffering
|
||||
} else {
|
||||
StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
}
|
||||
};
|
||||
let mut response_as_tool = using_tools;
|
||||
let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result{
|
||||
Ok(stream_token) => {
|
||||
let token_text = &stream_token.token.text.clone();
|
||||
let usage = stream_token.details.as_ref().map(|details| {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
Usage {
|
||||
completion_tokens,
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
let events = state.push(stream_token);
|
||||
match events{
|
||||
ChatEvent::NoTool => {
|
||||
chat.tools = None;
|
||||
chat.response_format = None;
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.clone().try_into_generate(&infer).unwrap();
|
||||
assert!(!using_tools);
|
||||
let (_headers, response_stream2) =
|
||||
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
|
||||
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
|
||||
response_stream = Box::pin(response_stream2);
|
||||
}
|
||||
});
|
||||
match state {
|
||||
StreamState::Buffering => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "no_tool" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
} else {
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
// send all the buffered messages
|
||||
for stream_token in &buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if we skipped sending the buffer we need to avoid sending the following json key and quotes
|
||||
StreamState::BufferTrailing => {
|
||||
let infix_text = "\"content\":\"";
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
// keep capturing until we find the infix text
|
||||
match json_buffer.find(infix_text) {
|
||||
Some(content_key_index) => {
|
||||
json_buffer =
|
||||
json_buffer[content_key_index + infix_text.len()..].to_string();
|
||||
}
|
||||
None => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// if there is leftover text after removing the infix text, we need to send it
|
||||
if !json_buffer.is_empty() {
|
||||
let event = Event::default();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
let chat_complete =
|
||||
CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
Some(json_buffer.clone()),
|
||||
None,
|
||||
current_time,
|
||||
None,
|
||||
None,
|
||||
));
|
||||
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
|
||||
InferError::StreamSerializationError(e.to_string()).into()
|
||||
ChatEvent::Events(events) => {
|
||||
for chat_complete in events{
|
||||
yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||
Event::default()
|
||||
}));
|
||||
}
|
||||
// cleanup the buffers
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: true,
|
||||
};
|
||||
}
|
||||
StreamState::Content { skip_close_quote } => {
|
||||
if skip_close_quote && token_text.contains('"') {
|
||||
break;
|
||||
}
|
||||
|
||||
// send the content
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
|
||||
let should_send_usage = usage.is_some()
|
||||
&& stream_options
|
||||
.as_ref()
|
||||
.is_some_and(|opts| opts.include_usage);
|
||||
|
||||
if should_send_usage {
|
||||
let usage_data = usage.unwrap();
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![],
|
||||
usage: Some(Usage {
|
||||
prompt_tokens: usage_data.prompt_tokens,
|
||||
completion_tokens: usage_data.completion_tokens,
|
||||
total_tokens: usage_data.total_tokens,
|
||||
}),
|
||||
});
|
||||
|
||||
yield Ok(Event::default()
|
||||
.json_data(chat_complete)
|
||||
.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()));
|
||||
}
|
||||
}
|
||||
Err(err) => yield Ok(err.into_openai_event())
|
||||
@ -1404,8 +1224,13 @@ pub(crate) async fn chat_completions(
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
let (headers, input_length, Json(generation)) =
|
||||
generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
|
||||
let (mut headers, mut input_length, Json(generation)) = generate_internal(
|
||||
Extension(infer.clone()),
|
||||
compute_type.clone(),
|
||||
Json(generate_request),
|
||||
span.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
@ -1413,55 +1238,25 @@ pub(crate) async fn chat_completions(
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if using_tools {
|
||||
let gen_text_value: Value =
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
InferError::ToolError(format!(
|
||||
"Failed to parse generated text: {} {:?}",
|
||||
e, generation.generated_text
|
||||
))
|
||||
})?;
|
||||
let function = gen_text_value.get("function").ok_or(InferError::ToolError(
|
||||
"No function found in generated text".to_string(),
|
||||
))?;
|
||||
|
||||
let name = function
|
||||
.get("_name")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or(InferError::ToolError(
|
||||
"No _name found in generated text".to_string(),
|
||||
))?
|
||||
.to_string();
|
||||
|
||||
let mut arguments = function.clone();
|
||||
if let Value::Object(ref mut props) = arguments {
|
||||
props.remove("_name");
|
||||
}
|
||||
match name.as_str() {
|
||||
"no_tool" => {
|
||||
// parse the content message
|
||||
let content_message = arguments
|
||||
.get("content")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
InferError::ToolError(
|
||||
"No `content` found in generated text".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_string();
|
||||
(None, Some(content_message))
|
||||
}
|
||||
_ => {
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name,
|
||||
arguments,
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
match crate::chat::parse_output(&generation.generated_text)? {
|
||||
ChatChoice::NoTool => {
|
||||
chat.tools = None;
|
||||
chat.response_format = None;
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.clone().try_into_generate(&infer)?;
|
||||
assert!(!using_tools);
|
||||
let (headers_final, input_length_final, Json(generation)) = generate_internal(
|
||||
Extension(infer),
|
||||
compute_type,
|
||||
Json(generate_request),
|
||||
span,
|
||||
)
|
||||
.await?;
|
||||
headers = headers_final;
|
||||
input_length = input_length_final;
|
||||
(None, Some(generation.generated_text))
|
||||
}
|
||||
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
|
||||
}
|
||||
} else {
|
||||
(None, Some(generation.generated_text))
|
||||
@ -1727,7 +1522,7 @@ pub async fn run(
|
||||
|
||||
// Shared API builder initialization
|
||||
let api_builder = || {
|
||||
let mut builder = ApiBuilder::new().with_progress(false);
|
||||
let mut builder = ApiBuilder::from_env().with_progress(false);
|
||||
if let Some(token) = authorization_token {
|
||||
builder = builder.with_token(Some(token));
|
||||
}
|
||||
@ -1781,6 +1576,7 @@ pub async fn run(
|
||||
tokenizer_config_filename,
|
||||
preprocessor_config_filename,
|
||||
processor_config_filename,
|
||||
chat_template_filename,
|
||||
model_info,
|
||||
) = match api {
|
||||
Type::None => (
|
||||
@ -1788,6 +1584,7 @@ pub async fn run(
|
||||
Some(local_path.join("tokenizer_config.json")),
|
||||
Some(local_path.join("preprocessor_config.json")),
|
||||
Some(local_path.join("processor_config.json")),
|
||||
Some(local_path.join("chat_template.json")),
|
||||
None,
|
||||
),
|
||||
Type::Api(api) => {
|
||||
@ -1801,6 +1598,7 @@ pub async fn run(
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||
let chat_template_filename = api_repo.get("chat_template.json").await.ok();
|
||||
|
||||
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
|
||||
Some(model_info)
|
||||
@ -1813,10 +1611,12 @@ pub async fn run(
|
||||
tokenizer_config_filename,
|
||||
preprocessor_config_filename,
|
||||
processor_config_filename,
|
||||
chat_template_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
Type::Cache(cache) => {
|
||||
tracing::info!("Cache {cache:?}");
|
||||
let repo = cache.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
@ -1827,23 +1627,41 @@ pub async fn run(
|
||||
repo.get("tokenizer_config.json"),
|
||||
repo.get("preprocessor_config.json"),
|
||||
repo.get("processor_config.json"),
|
||||
repo.get("chat_template.json"),
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// if chat_template_filename is present, load the chat template
|
||||
let chat_template: Option<crate::ChatTemplateVersions> = chat_template_filename
|
||||
.and_then(|f| std::fs::read_to_string(f).ok())
|
||||
.and_then(|c| {
|
||||
let res = serde_json::from_str::<crate::ChatTemplateStandalone>(&c);
|
||||
if let Err(e) = &res {
|
||||
tracing::warn!("Could not parse chat template {e:?}");
|
||||
}
|
||||
res.ok().map(|t| t.chat_template)
|
||||
});
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
|
||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||
{
|
||||
HubTokenizerConfig::from_file(filename)
|
||||
} else {
|
||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||
};
|
||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||
let mut tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
});
|
||||
|
||||
if chat_template.is_some() {
|
||||
tracing::info!("Using chat template from chat_template.json");
|
||||
tokenizer_config.chat_template = chat_template;
|
||||
}
|
||||
|
||||
let tokenizer: Result<Tokenizer, WebServerError> = {
|
||||
use pyo3::prelude::*;
|
||||
Python::with_gil(|py| -> PyResult<()> {
|
||||
|
@ -18,6 +18,7 @@ use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::warn;
|
||||
use tracing::{instrument, Span};
|
||||
use {once_cell::sync::Lazy, regex::Regex};
|
||||
|
||||
@ -694,6 +695,14 @@ fn image_tokens(
|
||||
"<|vision_start|>{:?}<|vision_end|>",
|
||||
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
|
||||
),
|
||||
Gemma3(_config) => {
|
||||
// TODO: prefer using the config to determine the number of features
|
||||
let num_mm_soft_tokens_per_image = 256;
|
||||
format!(
|
||||
"\n\n<start_of_image>{}<end_of_image>\n\n",
|
||||
"<image_soft_token>".repeat(num_mm_soft_tokens_per_image)
|
||||
)
|
||||
}
|
||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||
}
|
||||
}
|
||||
@ -721,8 +730,8 @@ fn prepare_input<T: TokenizerTrait>(
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
Some(
|
||||
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
|
||||
| Qwen2Vl(_) | Qwen2_5Vl(_)),
|
||||
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Paligemma(_)
|
||||
| LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
|
||||
) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
|
@ -39,6 +39,7 @@ install: install-cuda
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-flash-attention
|
||||
uv pip install -e ".[attention,bnb,marlin,moe]"
|
||||
uv pip install nvidia-nccl-cu12==2.22.3
|
||||
kernels download .
|
||||
|
||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
272
server/kernels.lock
Normal file
272
server/kernels.lock
Normal file
@ -0,0 +1,272 @@
|
||||
[
|
||||
{
|
||||
"repo_id": "kernels-community/paged-attention",
|
||||
"sha": "331b7e63a6b592799c8bc992f681bb1ee2c865a2",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-8e0aa39abab82f1d21b661d35e0470a24c3ebbdda38532ded805c18037a1ad1e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-b0c3aef6c4c9aac627975cb1a2bfc46a70390763c8165575b89d1651d007c38a",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-960fbc8998439d779adb47fb2a37cce68c7dc075d8a49893bd487be9ca2d1389",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-9d6d60c411c55aa2f9d7c681c2be96f4262d56c96f592f3d4fb35ce4f4f1e18e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-98c0a305b2cc9b7be757fab923d9aa406c686dcd0460e462926f87d051ef3d19",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-71e586416213c96ffbdeae0d077ba97bfde5b00005f2746d4cba2320cb53bf87",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-2f559312c54d558b33a4082ffc3fcf923f51da40ced19bfc8920e998ba2b71bf",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-6033b41a0f8a9509887c6171f0b42d9aa738490903b3fd5ea2c52703c5fb8fc3",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-3139f66a53f2bf0c314b4d309893095746bdc9c3914c904fc31adfdf553ed219",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-2173d77e384d8e2881fc38603992c09e8be7bcd9da4cafdd4f2a5ce0ce22caf4",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-7b1aaef81e01ecce83e03c50872910680ff2953f7c6ffd3ff15e8d9497ca9239",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-818b160a88b12b8e871099e40f76aa436ee828e2e060ecc35502dbe34a6ebd3b",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"repo_id": "kernels-community/moe",
|
||||
"sha": "605a216f507b9a97b543140dee8937a4622069a8",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-855d92f02be3bfba0758161fa1266159d76c172e7c5d43d30816d22cfba76074",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-e6e780230477bbbc26fc40cc7fcff50298155998af4fc77a026c9f815ec984b1",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-52c1fb337033c4d1d7a279c5cb28aebbc7389976f21dc5803aeb16b2f7aeb94c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-1fb654e8d02dda2a2382d1fb3a3ca9738d292eea674b30b80030cdcdfb6a0035",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-0cf235f1de85d4ce7490c79aa64220f608f886f313b676d91c331a6a2fd67bbb",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-3def11fee9bf1ea9b1579206fd5f5ecbcaad47ac478e2c3aa7b2c9c7fd5db934",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-3a49ee03f675190a79c7c74a45cc403d491eceb63a943f47d52064a11ca6ef6f",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-dbf20cb11db7d53e11147ab13641eefaa235f9ac2fde1beaf8f56f850c11bd54",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-8a07232ab316e8eab74747662cb7b86aac03f44ff158f275768fd59390df2525",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-cdd46301af997eeace5e016d8590969981b3a3f8647828d04baa5fa10c696746",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-c865188e9d2c17f3358f3d343fb40340232457572744bf85efd6b20af545d5f3",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-2a8b09f3272ea80491e78a39ff886680471d99f7ba571581809adfe918013898",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"repo_id": "kernels-community/quantization",
|
||||
"sha": "95272c71ca71b1ddbacb0105dab54e5d5240bd5c",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-2d0a274cf0117bf7880d6040adafa1b70fe8bff3a00ef2834ed5435a6b525a49",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-116458beac63ea5eeb1e7fba7edc68d160cd8ac28f55b926d79035551aac7d5f",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-cace644c6fb04470384796c18987135cb051dfb90a14e902c51a3786fc07c599",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-104c6961cd3e1a74efdf14ea2172acc6647846852fccafe3698a27a6cf37941d",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-cdc95b41aa91a803f11f8cd53001895c2b69550b5af2fb278d6f124381229d0b",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-d5388469cb6074f196f20b1e1e4805bb3c967a8147b31ca2c0461aa87b50604e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-70c4bb3792c4c3207d4963173d8d0ef3b2bda677151aef140662dd87bfa1b69f",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-bcacbb2232f49345f27e07fa821b48a7e3df643c01af37281fcafc74c471f682",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-344d20964f7eb133e5ec6fda976fa5ee62807b739a4361f236aca5ae53beb9ac",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-dfaec226550254fbce1a5c7e2f547e85700958a1a4087e1c873d22e6f71a5ceb",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-0abe6460d0a2202b0086e3663092595e5b93b9a9cbb85c10034180cc9bfebc6e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-68e156f94c3c0c9523773b62eaeced93766e0d9ee67d8191fb9570fb5af30d5b",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"repo_id": "kernels-community/quantization-eetq",
|
||||
"sha": "a80ce846d6270ddddeee109523ed947f594f246b",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-e06beb00799b1e656583eb0496f09fc0bf1b26f75e9864a2fe19ebd5b62c3671",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-c128d3ef6558cfedf045c4a713891792708851b7f6f027de835d9083cb3b297d",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-c7e2e14fc114788634b34a4f670f7bf4d27321e5ed40ff446f5a25eef70222c7",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-58dad53cfbf1315af464f9d8ba7be9012089c839d4f06a8d2cf8ce0deaf5949a",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-6519af49c0f689744a7b49497ad2bea1524b69e4095446087d7ab622b898aa30",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-94e0731b58a9ba0e5e2f37b100c8d987c80b5d349008ef625917d020b6c52d25",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-e5b04475538f49d7b4ffded080e4c9c86a658abc12667e3838ebcc410ab1eef4",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-783c02db737a6ec9958b3090f164b87888d3b26e30a4fb6e1cd0c1a635753fab",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-a3d81f82f9cfe9d8a6d46758758b3a1b3055d902f41917b4ef2976373db843d6",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-f1de67e17944a9816f778c72ae73bbbc90d795cb4885c2f9ee5e0b9a3c57583b",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-789b50d767a5121a7e5a52eaf0c8e897bf1787f049ca08faffb220e5053a5f10",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-7c7fe57fea7b9be253085d506f01b2487b2306f22bdffe1de44397fc9f8a3613",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"repo_id": "kernels-community/rotary",
|
||||
"sha": "4db658e027ec752840bb3f557ee076413b8db03f",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-907df2035267a65793985bb7f69fb2a975955fb08c2bbc78c58def43d02801da",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-b614735ae61ee2c1825a3c823fa0cdd3aa07d0bb3f4106001b9e1a557c0ca9b9",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-f2e98ec72faaebc1cae25f83ccdbb151868b6902fb5a0623e09d700a514c2a7e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-421214c5a576fac2e0b7998395dccd7f66010f65a6fc647ce06b106ea91105d2",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-9d1c464cf7f391975afa48f2254a639f41582155ad1b50c25bb122418ce8db58",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-82f8012d78304efaa7318f106907630294d10c8b5c9f56923c71df0b03e09f14",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-a3247919dcc392efc7e54725dfbce9ee8a796fe4ee53d113048b313de074d3da",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-a21c9734d15946f4cc967d0555d45d7effc6624990c6889fc49162af744fbbe9",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-01cdda160425b29db0d9bb084874ade4ac081735f9717f272aaefe5bcb379ae1",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-17be5b770418ad47101c49d8945b5aa32af9eb5a840bdffb0514d0e264edd860",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-3cd4b9f63cc903e01325b7e5b204e40fc6600c0685f2e19e6f1fa604a599d82d",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-c569f4a4f9b64792507c58d7cfa31dde1285b52125ef07cc98d9f23636af09ca",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
@ -14,7 +14,7 @@ dependencies = [
|
||||
"grpcio>=1.67.0",
|
||||
"grpcio-reflection>=1.67.0",
|
||||
"grpcio-status>=1.67.0",
|
||||
"hf-kernels>=0.1.5",
|
||||
"kernels>=0.2.1",
|
||||
"hf-transfer>=0.1.8",
|
||||
"loguru>=0.7.3",
|
||||
"numpy>=1.26,<3",
|
||||
@ -36,7 +36,7 @@ dependencies = [
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hf-kernels>=0.1.2", "setuptools"]
|
||||
requires = ["kernels>=0.1.7", "setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.kernels.dependencies]
|
||||
|
@ -205,7 +205,6 @@ class LoraWeights(AdapterWeights):
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
if key not in target_to_layer:
|
||||
|
@ -38,6 +38,7 @@ def paged_attention(
|
||||
*,
|
||||
kv_scales: KVScales,
|
||||
softcap: Optional[float] = None,
|
||||
window_size_left: Optional[int] = -1,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
# Copyright 2023 The vLLM team. All rights
|
||||
@ -79,12 +80,15 @@ def paged_attention(
|
||||
sm_scale=softmax_scale,
|
||||
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
||||
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
||||
window_left=window_size_left,
|
||||
)
|
||||
elif ATTENTION == "flashdecoding":
|
||||
max_q = 1
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
||||
window_size_right = -1 if window_size_left == -1 else 0
|
||||
|
||||
# TODO fixme when flash contains the fix.
|
||||
# Number of splits is not correctly handled
|
||||
# by the current path
|
||||
@ -109,8 +113,8 @@ def paged_attention(
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
True, # causal
|
||||
-1, # Window_left
|
||||
-1, # Window right
|
||||
window_size_left, # Window_left
|
||||
window_size_right, # Window right
|
||||
softcap,
|
||||
False, # return softmax
|
||||
None, # generator
|
||||
@ -253,6 +257,7 @@ def attention(
|
||||
sm_scale=softmax_scale,
|
||||
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
||||
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
||||
window_left=window_size_left,
|
||||
)
|
||||
|
||||
# If we are using flashdecoding or paged, we always use flash-attn for
|
||||
|
@ -52,7 +52,6 @@ def use_prefill_with_paged_kv_state(
|
||||
page_size: int,
|
||||
kv_dtype: torch.dtype,
|
||||
q_dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
@ -95,7 +94,6 @@ def use_prefill_with_paged_kv_state(
|
||||
kv_data_type=kv_dtype,
|
||||
q_data_type=q_dtype,
|
||||
page_size=page_size,
|
||||
window_left=-1 if window_left is None else window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
@ -172,7 +170,6 @@ def use_decode_state(
|
||||
page_size: int,
|
||||
kv_cache_dtype: torch.dtype,
|
||||
q_dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer decoding state to the given
|
||||
@ -209,7 +206,6 @@ def use_decode_state(
|
||||
page_size=page_size,
|
||||
data_type=kv_cache_dtype,
|
||||
q_data_type=q_dtype,
|
||||
window_left=-1 if window_left is None else window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
|
@ -78,6 +78,7 @@ def paged_attention(
|
||||
*,
|
||||
kv_scales: KVScales,
|
||||
softcap: Optional[float] = None,
|
||||
window_size_left: Optional[int] = -1,
|
||||
):
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is not available in IPEX")
|
||||
|
@ -59,6 +59,7 @@ def paged_attention(
|
||||
*,
|
||||
kv_scales: KVScales,
|
||||
softcap: Optional[float] = None,
|
||||
window_size_left: Optional[int] = -1,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
# Copyright 2023 The vLLM team. All rights
|
||||
@ -82,6 +83,8 @@ def paged_attention(
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
||||
window_size_right = -1 if window_size_left == -1 else 0
|
||||
|
||||
if softcap is None:
|
||||
softcap = 0.0
|
||||
out = flash_attn_2_cuda.varlen_fwd(
|
||||
@ -101,8 +104,8 @@ def paged_attention(
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
True, # causal
|
||||
-1, # Window_left
|
||||
-1, # Window right
|
||||
window_size_left, # Window_left
|
||||
window_size_right, # Window right
|
||||
softcap,
|
||||
False, # return softmax
|
||||
None, # generator
|
||||
|
@ -106,6 +106,17 @@ try:
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
|
||||
FlashGemma3ForCausalLM,
|
||||
Gemma3ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.gemma3.processing_gemma3 import (
|
||||
Gemma3Processor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.gemma3.configuration_gemma3 import (
|
||||
Gemma3Config,
|
||||
Gemma3TextConfig,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||
FlashDbrxForCausalLM,
|
||||
DbrxConfig,
|
||||
@ -258,6 +269,16 @@ class ModelType(enum.Enum):
|
||||
"name": "Gemma2",
|
||||
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
|
||||
}
|
||||
GEMMA3 = {
|
||||
"type": "gemma3",
|
||||
"name": "Gemma3",
|
||||
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
|
||||
}
|
||||
GEMMA3_TEXT = {
|
||||
"type": "gemma3_text",
|
||||
"name": "Gemma3 Text",
|
||||
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
|
||||
}
|
||||
COHERE = {
|
||||
"type": "cohere",
|
||||
"name": "Cohere",
|
||||
@ -1094,6 +1115,83 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == GEMMA3_TEXT:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemma3ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# TODO: once implemented in transformers, use the config class
|
||||
# and processor class from there.
|
||||
config_class=Gemma3TextConfig,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif FLASH_TRANSFORMERS_BACKEND:
|
||||
return TransformersFlashCausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == GEMMA3:
|
||||
if FLASH_ATTENTION:
|
||||
# TODO: Use VlmCausalLM when image support is added.
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Gemma3ForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# TODO: once implemented in transformers, use the config class
|
||||
# and processor class from there.
|
||||
config_class=Gemma3Config,
|
||||
processor_class=Gemma3Processor,
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif FLASH_TRANSFORMERS_BACKEND:
|
||||
return TransformersFlashCausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
|
||||
else:
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == COHERE:
|
||||
if FLASH_ATTENTION:
|
||||
|
@ -287,6 +287,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||
max_s,
|
||||
softcap=self.softcap,
|
||||
kv_scales=self.kv_scales,
|
||||
window_size_left=self.window_size,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
|
@ -0,0 +1,902 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from typing import Optional, List, Tuple
|
||||
import copy
|
||||
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
#
|
||||
SpeculativeHead,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
|
||||
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
from transformers.activations import ACT2FN
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
Seqlen,
|
||||
)
|
||||
|
||||
|
||||
ATTENTION_TYPE_GLOBAL = "global"
|
||||
ATTENTION_TYPE_LOCAL = "local_sliding"
|
||||
|
||||
|
||||
class Gemma3FastRMSNorm(FastRMSNorm):
|
||||
@classmethod
|
||||
def load(cls, prefix: str, weights, eps=1e-6):
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
weights.dtype = dtype
|
||||
new = cls(weight, eps)
|
||||
new.dtype = dtype
|
||||
return new
|
||||
|
||||
# perform the multiplication in full precision and downcast after
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states.to(self.dtype), residual
|
||||
|
||||
|
||||
def load_attention(config, prefix: str, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if isinstance(weight, UnquantizedWeight):
|
||||
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.head_dim
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=None))
|
||||
|
||||
|
||||
class FlashGemma3Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = config.head_dim
|
||||
self.causal = causal
|
||||
if is_sliding:
|
||||
self.window_size = config.sliding_window
|
||||
# TODO: remove this hack to support local sliding window
|
||||
config = copy.deepcopy(config)
|
||||
config.rope_scaling = dict(rope_type="default")
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=config.head_dim,
|
||||
base=config.rope_local_base_freq,
|
||||
device=weights.device,
|
||||
)
|
||||
else:
|
||||
self.window_size = -1
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=config.head_dim,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = (
|
||||
config.query_pre_attn_scalar**-0.5
|
||||
if config.query_pre_attn_scalar is not None
|
||||
else None
|
||||
)
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
self.softcap = None # config.attn_logit_softcapping
|
||||
|
||||
query_key_value = load_attention(config, prefix, weights)
|
||||
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||
query_key_value,
|
||||
layer_id,
|
||||
["q_proj", "k_proj", "v_proj"],
|
||||
sizes=[
|
||||
self.head_size * config.num_attention_heads,
|
||||
self.head_size * config.num_key_value_heads,
|
||||
self.head_size * config.num_key_value_heads,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
layer_id,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
self.q_norm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.k_norm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.enable_gqa = self.num_heads != self.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
attention_mask,
|
||||
):
|
||||
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
|
||||
key = kv[:, 0]
|
||||
value = kv[:, 1]
|
||||
|
||||
query = query.reshape(-1, self.head_size)
|
||||
key = key.reshape(-1, self.head_size)
|
||||
|
||||
query, _ = self.q_norm(query.contiguous())
|
||||
key, _ = self.k_norm(key.contiguous())
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||
value = value.view(-1, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
if attention_mask is None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.window_size,
|
||||
softcap=self.softcap,
|
||||
)
|
||||
else:
|
||||
lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
|
||||
|
||||
# Split tensors using vectorized split
|
||||
query_list = torch.split(query, lengths.tolist(), dim=0)
|
||||
key_list = torch.split(key, lengths.tolist(), dim=0)
|
||||
value_list = torch.split(value, lengths.tolist(), dim=0)
|
||||
|
||||
padded_query = torch.nn.utils.rnn.pad_sequence(
|
||||
query_list, batch_first=True
|
||||
)
|
||||
padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
|
||||
padded_value = torch.nn.utils.rnn.pad_sequence(
|
||||
value_list, batch_first=True
|
||||
)
|
||||
|
||||
padded_query = padded_query.transpose(1, 2).contiguous()
|
||||
padded_key = padded_key.transpose(1, 2).contiguous()
|
||||
padded_value = padded_value.transpose(1, 2).contiguous()
|
||||
|
||||
# Compute attention
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
padded_query,
|
||||
padded_key,
|
||||
padded_value,
|
||||
attn_mask=attention_mask,
|
||||
scale=self.softmax_scale,
|
||||
enable_gqa=self.enable_gqa,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(
|
||||
1, 2
|
||||
) # [batch_size, seq_len, num_heads, head_dim]
|
||||
max_seq_len = padded_query.size(2)
|
||||
seq_range = torch.arange(
|
||||
max_seq_len, device=padded_query.device
|
||||
).unsqueeze(0)
|
||||
lengths_tensor = torch.tensor(
|
||||
lengths, device=padded_query.device
|
||||
).unsqueeze(1)
|
||||
mask = seq_range < lengths_tensor # [batch, max_seq_len]
|
||||
attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim]
|
||||
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
softcap=self.softcap,
|
||||
kv_scales=self.kv_scales,
|
||||
window_size_left=self.window_size,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class Gemma3MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights, layer_id):
|
||||
super().__init__()
|
||||
act = config.hidden_activation
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
layer_id,
|
||||
["gate_proj", "up_proj"],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
layer_id,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
|
||||
|
||||
class FlashGemma3Layer(nn.Module):
|
||||
def __init__(
|
||||
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGemma3Attention(
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_id=layer_id,
|
||||
causal=causal,
|
||||
is_sliding=is_sliding,
|
||||
)
|
||||
self.mlp = Gemma3MLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||
)
|
||||
|
||||
self.input_layernorm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.pre_feedforward_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_feedforward_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
attention_mask,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
|
||||
normed_attn_res_output = normed_attn_res_output + res
|
||||
res = normed_attn_res_output
|
||||
|
||||
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
||||
mlp_output = self.mlp(pre_normed, adapter_data)
|
||||
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
||||
|
||||
return post_hidden_states, normed_attn_res_output
|
||||
|
||||
|
||||
class FlashGemma3Model(torch.nn.Module):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGemma3Layer(
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_id=layer_id,
|
||||
causal=causal,
|
||||
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask_local: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
cos, sin = self.layers[i].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
(
|
||||
attention_mask
|
||||
if self.layers[i].self_attn.window_size == -1
|
||||
else attention_mask_local
|
||||
),
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashGemma3ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
|
||||
super().__init__()
|
||||
|
||||
embed_norm = config.hidden_size**0.5
|
||||
if not prefix:
|
||||
prefix = "model"
|
||||
else:
|
||||
prefix = f"{prefix}.model"
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_tokens.weight *= embed_norm
|
||||
|
||||
self.model = FlashGemma3Model(
|
||||
prefix=prefix, config=config, weights=weights, causal=causal
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
prefix=(
|
||||
f"{prefix}.embed_tokens"
|
||||
if config.tie_word_embeddings
|
||||
else f"{prefix}.lm_head"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
# self.softcap = config.attn_logit_softcapping
|
||||
# assert isinstance(self.softcap, float)
|
||||
self.softcap = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class Gemma3MultimodalInputProjection(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.mm_input_projection_weight = weights.get_tensor(
|
||||
"multi_modal_projector.mm_input_projection_weight"
|
||||
)
|
||||
|
||||
self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
|
||||
prefix=f"{prefix}.mm_soft_emb_norm",
|
||||
weights=weights,
|
||||
eps=config.vision_config.layer_norm_eps,
|
||||
)
|
||||
|
||||
self.patches_per_image = int(
|
||||
config.vision_config.image_size // config.vision_config.patch_size
|
||||
)
|
||||
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
||||
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
||||
self.avg_pool = nn.AvgPool2d(
|
||||
kernel_size=self.kernel_size, stride=self.kernel_size
|
||||
)
|
||||
|
||||
def forward(self, vision_outputs: torch.Tensor):
|
||||
batch_size, _, seq_length = vision_outputs.shape
|
||||
|
||||
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
||||
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
||||
batch_size, seq_length, self.patches_per_image, self.patches_per_image
|
||||
)
|
||||
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
||||
|
||||
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
||||
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
||||
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
||||
|
||||
normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)
|
||||
|
||||
projected_vision_outputs = torch.matmul(
|
||||
normed_vision_outputs, self.mm_input_projection_weight
|
||||
)
|
||||
return projected_vision_outputs.type_as(vision_outputs)
|
||||
|
||||
|
||||
class Gemma3ForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
if config.vision_config is not None:
|
||||
|
||||
config.vision_config.quantize = config.quantize
|
||||
|
||||
self.post_vision_model_layernorm = nn.LayerNorm.load(
|
||||
prefix="vision_tower.vision_model.post_layernorm",
|
||||
weights=weights,
|
||||
eps=config.vision_config.layer_norm_eps,
|
||||
)
|
||||
|
||||
self.multimodal_projector = Gemma3MultimodalInputProjection(
|
||||
prefix="multi_modal_projector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
text_config = config.text_config
|
||||
text_config.speculator = config.speculator
|
||||
text_config.quantize = config.quantize
|
||||
|
||||
self.vision_model = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.text_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
else:
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
self.text_model = load_text_model(
|
||||
prefix=prefix,
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def get_attention_mask(
|
||||
self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
|
||||
):
|
||||
device = input_ids.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
|
||||
batch_size = len(lengths)
|
||||
|
||||
sequence_length = max(lengths)
|
||||
target_length = sequence_length
|
||||
# Create the padding mask from the computed lengths.
|
||||
# pad_mask: [batch, sequence_length] where True indicates valid tokens.
|
||||
seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
|
||||
lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
|
||||
pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
|
||||
|
||||
# Build the base causal mask (for non-image tokens):
|
||||
causal_mask = torch.tril(
|
||||
torch.ones(
|
||||
(sequence_length, sequence_length), dtype=torch.bool, device=device
|
||||
)
|
||||
)
|
||||
base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
|
||||
1
|
||||
) # [batch, sequence_length, sequence_length]
|
||||
base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint
|
||||
|
||||
image_token_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
|
||||
)
|
||||
bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
|
||||
1
|
||||
)
|
||||
|
||||
# Combine the causal base mask and the bidirectional mask.
|
||||
combined_mask = torch.logical_or(
|
||||
base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
|
||||
).to(device)
|
||||
# combined_mask now has shape [batch, 1, sequence_length, sequence_length]
|
||||
|
||||
full_attention_mask = torch.zeros(
|
||||
(batch_size, 1, sequence_length, target_length),
|
||||
device=device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
full_attention_mask[:, :, :, :sequence_length] = combined_mask
|
||||
|
||||
final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
|
||||
|
||||
return final_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused here
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if cu_seqlen_prefill is not None:
|
||||
max_s += 1
|
||||
position_ids += 1
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||
image_outputs = self.vision_model(pixel_values)
|
||||
vision_outputs = self.post_vision_model_layernorm(
|
||||
image_outputs.last_hidden_state
|
||||
)
|
||||
image_features = self.multimodal_projector(vision_outputs)
|
||||
|
||||
image_token_mask = (input_ids == self.config.image_token_index).to(
|
||||
input_ids.device
|
||||
)
|
||||
inputs_embeds[image_token_mask] = image_features.view(
|
||||
-1, image_features.shape[-1]
|
||||
)
|
||||
attention_mask = self.get_attention_mask(
|
||||
input_ids,
|
||||
max_s,
|
||||
cu_seqlen_prefill,
|
||||
inputs_embeds.dtype,
|
||||
image_token_mask,
|
||||
)
|
||||
# Use flash attention for text-only input
|
||||
# else:
|
||||
# if cu_seqlen_prefill is not None:
|
||||
# min_dtype = torch.finfo(inputs_embeds.dtype).min
|
||||
# lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
|
||||
|
||||
# # Determine the maximum sequence length (after padding) from query.
|
||||
# sequence_length = max(lengths)
|
||||
# target_length = sequence_length
|
||||
|
||||
# # Create the padding mask from the computed lengths.
|
||||
# # pad_mask: [batch, sequence_length] where True indicates valid tokens.
|
||||
# seq_range = torch.arange(
|
||||
# sequence_length, device=input_ids.device
|
||||
# ).unsqueeze(0)
|
||||
# lengths_tensor = torch.tensor(
|
||||
# lengths, device=input_ids.device
|
||||
# ).unsqueeze(1)
|
||||
# pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
|
||||
|
||||
# # Build the base causal mask (for non-image tokens):
|
||||
# causal_mask = torch.tril(
|
||||
# torch.ones(
|
||||
# (sequence_length, sequence_length),
|
||||
# dtype=torch.bool,
|
||||
# device=input_ids.device,
|
||||
# )
|
||||
# )
|
||||
# base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
|
||||
# 1
|
||||
# ) # [batch, sequence_length, sequence_length]
|
||||
# base_mask = base_mask & causal_mask.unsqueeze(0)
|
||||
# attention_mask = base_mask.unsqueeze(
|
||||
# 1
|
||||
# ) # [batch, 1, sequence_length, sequence_length]
|
||||
# full_attention_mask = torch.zeros(
|
||||
# (len(lengths), 1, sequence_length, target_length),
|
||||
# device=input_ids.device,
|
||||
# dtype=torch.bool,
|
||||
# )
|
||||
# full_attention_mask[:, :, :, :sequence_length] = attention_mask
|
||||
|
||||
# attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(
|
||||
# input_ids.device
|
||||
# )
|
||||
|
||||
if attention_mask is not None:
|
||||
min_dtype = torch.finfo(inputs_embeds.dtype).min
|
||||
# prefill may be larger than sliding window
|
||||
effective_seq_len = max(
|
||||
position_ids.shape[0], self.config.text_config.sliding_window
|
||||
)
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool),
|
||||
diagonal=-self.config.text_config.sliding_window,
|
||||
)
|
||||
attention_mask_local = torch.where(
|
||||
sliding_window_mask, min_dtype, attention_mask
|
||||
)
|
||||
offset = max(0, position_ids.shape[0] - effective_seq_len)
|
||||
attention_mask_local = attention_mask_local[
|
||||
:, :, :, offset : offset + effective_seq_len
|
||||
]
|
||||
else:
|
||||
attention_mask_local = None
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask_local=attention_mask_local,
|
||||
)
|
||||
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
|
||||
# pad logit with 1 zero logit for the image token
|
||||
if pixel_values is not None:
|
||||
logits = torch.cat(
|
||||
[logits, torch.zeros(logits.size(0), 1, device=logits.device)], dim=1
|
||||
)
|
||||
if speculative_logits is not None:
|
||||
speculative_logits = torch.cat(
|
||||
[
|
||||
speculative_logits,
|
||||
torch.zeros(
|
||||
speculative_logits.size(0),
|
||||
1,
|
||||
device=speculative_logits.device,
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return logits, speculative_logits
|
@ -242,6 +242,7 @@ class MistralAttention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
|
@ -290,6 +290,7 @@ class MixtralAttention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
@ -74,7 +74,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_past = (
|
||||
self.window_size = (
|
||||
config.sliding_window if config.sliding_window is not None else -1
|
||||
)
|
||||
self.num_heads = config.num_attention_heads
|
||||
@ -172,7 +172,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
window_size_left=self.window_size,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
@ -185,6 +185,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
window_size_left=self.window_size,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
@ -405,10 +406,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.max_past = config.sliding_window
|
||||
self.max_past_tensor = (
|
||||
self.window_size = config.sliding_window
|
||||
self.window_size_tensor = (
|
||||
torch.tensor(config.sliding_window, device=weights.device)
|
||||
if self.max_past is not None
|
||||
if self.window_size is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@ -430,10 +431,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif self.max_past is not None:
|
||||
elif self.window_size is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
||||
seqlen = seqlen.clamp(max=self.window_size_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
|
@ -291,6 +291,7 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
|
@ -0,0 +1,313 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma3.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
from transformers import SiglipVisionConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma3-4B.
|
||||
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 262144):
|
||||
Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Gemma3Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 2304):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 9216):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 26):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window
|
||||
attention. This is the size of the sliding window.
|
||||
query_pre_attn_scalar (`float`, *optional*):
|
||||
The scaling factor used on the attention scores, not that
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings used for global attention.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to
|
||||
`"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
|
||||
activation function.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
final_logit_softcapping (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply logit softcapping or nor
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
|
||||
Scaling factor when applying tanh soft-capping on the attention scorexs.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
|
||||
The cache type to be used with `generate`.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3Model, Gemma3TextConfig
|
||||
>>> # Initializing a Gemma3 gemma3-4b style configuration
|
||||
>>> configuration = Gemma3Config()
|
||||
>>> # Initializing a model from the gemma3-4b style configuration
|
||||
>>> model = Gemma3Model(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma3_text"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 262_144,
|
||||
hidden_size: int = 2304,
|
||||
intermediate_size: int = 9216,
|
||||
num_hidden_layers: int = 26,
|
||||
num_attention_heads: int = 8,
|
||||
num_key_value_heads: int = 4,
|
||||
head_dim: int = 256,
|
||||
sliding_window: int = 4096,
|
||||
query_pre_attn_scalar: Optional[float] = 256,
|
||||
rope_theta: float = 1_000_000.0,
|
||||
rope_scaling=None,
|
||||
rope_local_base_freq: float = 10_000.0,
|
||||
sliding_window_pattern: int = 6,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
hidden_activation: str = "gelu_pytorch_tanh",
|
||||
pad_token_id: int = 0,
|
||||
eos_token_id: int = 1,
|
||||
bos_token_id: int = 2,
|
||||
tie_word_embeddings: bool = True,
|
||||
max_position_embeddings: int = 131_072,
|
||||
initializer_range: float = 0.02,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
use_cache: bool = True,
|
||||
final_logit_softcapping=None,
|
||||
attn_logit_softcapping=None,
|
||||
cache_implementation: str = "hybrid",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_local_base_freq = rope_local_base_freq
|
||||
# For configuring HybridCache to work with 5:1 attention pattern
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_activation = hidden_activation
|
||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||
self.sliding_window = sliding_window
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.cache_implementation = cache_implementation
|
||||
rope_config_validation(self)
|
||||
|
||||
|
||||
class Gemma3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
|
||||
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
|
||||
|
||||
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
|
||||
The config object of the text backbone.
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
Custom vision config or dict.
|
||||
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
||||
The number of tokens per image embedding.
|
||||
boi_token_index (`int`, *optional*, defaults to 255999):
|
||||
The begin-of-image token index to wrap the image prompt.
|
||||
eoi_token_index (`int`, *optional*, defaults to 256000):
|
||||
The end-of-image token index to wrap the image prompt.
|
||||
image_token_index (`int`, *optional*, defaults to 262144):
|
||||
The image token index to encode the image prompt.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
||||
|
||||
>>> # Initializing a Siglip-like vision config
|
||||
>>> vision_config = SiglipVisionConfig()
|
||||
|
||||
>>> # Initializing a Gemma3 Text config
|
||||
>>> text_config = Gemma3TextConfig()
|
||||
|
||||
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
||||
>>> configuration = Gemma3Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the gemma-3-4b style configuration
|
||||
>>> model = Gemma3TextConfig(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma3"
|
||||
sub_configs = {
|
||||
"text_config": Gemma3TextConfig,
|
||||
"vision_config": SiglipVisionConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config: Optional[Gemma3TextConfig] = None,
|
||||
vision_config: Optional[SiglipVisionConfig] = None,
|
||||
mm_tokens_per_image: int = 256,
|
||||
boi_token_index: int = 255_999,
|
||||
eoi_token_index: int = 256_000,
|
||||
image_token_index: int = 262_144,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
if text_config is None:
|
||||
text_config = Gemma3TextConfig()
|
||||
logger.info(
|
||||
"text_config is None, using default Gemma3TextConfig vision config."
|
||||
)
|
||||
elif isinstance(text_config, dict):
|
||||
text_config = Gemma3TextConfig(**text_config)
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
vision_config = SiglipVisionConfig()
|
||||
logger.info(
|
||||
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
|
||||
"to text tasks."
|
||||
)
|
||||
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.mm_tokens_per_image = mm_tokens_per_image
|
||||
self.boi_token_index = boi_token_index
|
||||
self.eoi_token_index = eoi_token_index
|
||||
self.image_token_index = image_token_index
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["Gemma3Config", "Gemma3TextConfig"]
|
@ -0,0 +1,463 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for Gemma3."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_processing_utils import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
get_size_dict,
|
||||
)
|
||||
from transformers.image_transforms import (
|
||||
convert_to_rgb,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
filter_out_non_signature_kwargs,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
from .utils import make_nested_list_of_images
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
class Gemma3ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a SigLIP image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
||||
`do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
do_pan_and_scan (`bool`, *optional*):
|
||||
Whether to apply `pan_and_scan` to images.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = False,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
do_pan_and_scan: bool = None,
|
||||
pan_and_scan_min_crop_size: int = None,
|
||||
pan_and_scan_max_num_crops: int = None,
|
||||
pan_and_scan_min_ratio_to_activate: float = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_pan_and_scan = do_pan_and_scan
|
||||
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
|
||||
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
|
||||
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
|
||||
|
||||
def pan_and_scan(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pan and Scan and image, whatever it means. TODO: write-up docs
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
pan_and_scan_min_crop_size (`int`):
|
||||
Size of pan_and_scan_min_crop_size.
|
||||
pan_and_scan_max_num_crops (`int`):
|
||||
pan_and_scan_max_num_crops for the image.
|
||||
pan_and_scan_min_ratio_to_activate (`int`):
|
||||
pan_and_scan_min_ratio_to_activate for the image..
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
height, width = get_image_size(image)
|
||||
|
||||
# Square or landscape image.
|
||||
if width >= height:
|
||||
# Only apply PaS if the image is sufficiently exaggerated
|
||||
if width / height < pan_and_scan_min_ratio_to_activate:
|
||||
return []
|
||||
|
||||
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||
num_crops_w = int(
|
||||
math.floor(width / height + 0.5)
|
||||
) # Half round up rounding.
|
||||
num_crops_w = min(
|
||||
int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w
|
||||
)
|
||||
|
||||
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||
num_crops_w = max(2, num_crops_w)
|
||||
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
||||
num_crops_h = 1
|
||||
|
||||
# Portrait image.
|
||||
else:
|
||||
# Only apply PaS if the image is sufficiently exaggerated
|
||||
if height / width < pan_and_scan_min_ratio_to_activate:
|
||||
return []
|
||||
|
||||
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||
num_crops_h = int(math.floor(height / width + 0.5))
|
||||
num_crops_h = min(
|
||||
int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h
|
||||
)
|
||||
|
||||
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||
num_crops_h = max(2, num_crops_h)
|
||||
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
||||
num_crops_w = 1
|
||||
|
||||
crop_size_w = int(math.ceil(width / num_crops_w))
|
||||
crop_size_h = int(math.ceil(height / num_crops_h))
|
||||
|
||||
# Don't apply PaS if crop size is too small.
|
||||
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
||||
return []
|
||||
|
||||
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
|
||||
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
||||
|
||||
if input_data_format == ChannelDimension.LAST:
|
||||
image_crops = [
|
||||
image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
for pos_h, pos_w in itertools.product(
|
||||
crop_positions_h, crop_positions_w
|
||||
)
|
||||
]
|
||||
else:
|
||||
image_crops = [
|
||||
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
for pos_h, pos_w in itertools.product(
|
||||
crop_positions_h, crop_positions_w
|
||||
)
|
||||
]
|
||||
|
||||
return image_crops
|
||||
|
||||
def _process_images_for_pas(
|
||||
self,
|
||||
images: List[np.ndarray],
|
||||
do_pan_and_scan: bool,
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
pas_images_list = []
|
||||
num_crops = []
|
||||
for image in images:
|
||||
pas_images = self.pan_and_scan(
|
||||
image=image,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
pas_images_list.extend([image] + pas_images)
|
||||
num_crops.append(len(pas_images))
|
||||
return pas_images_list, num_crops
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
do_pan_and_scan: bool = None,
|
||||
pan_and_scan_min_crop_size: int = None,
|
||||
pan_and_scan_max_num_crops: int = None,
|
||||
pan_and_scan_min_ratio_to_activate: float = None,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to apply `pan_and_scan` to images.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, param_name="size", default_to_square=False)
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = (
|
||||
rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
)
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = (
|
||||
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
)
|
||||
do_pan_and_scan = (
|
||||
do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
|
||||
)
|
||||
pan_and_scan_min_crop_size = (
|
||||
pan_and_scan_min_crop_size
|
||||
if pan_and_scan_min_crop_size is not None
|
||||
else self.pan_and_scan_min_crop_size
|
||||
)
|
||||
pan_and_scan_max_num_crops = (
|
||||
pan_and_scan_max_num_crops
|
||||
if pan_and_scan_max_num_crops is not None
|
||||
else self.pan_and_scan_max_num_crops
|
||||
)
|
||||
pan_and_scan_min_ratio_to_activate = (
|
||||
pan_and_scan_min_ratio_to_activate
|
||||
if pan_and_scan_min_ratio_to_activate is not None
|
||||
else self.pan_and_scan_min_ratio_to_activate
|
||||
)
|
||||
|
||||
images_list = make_nested_list_of_images(images)
|
||||
|
||||
if not valid_images(images_list[0]):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
if do_convert_rgb:
|
||||
images_list = [
|
||||
[convert_to_rgb(image) for image in images] for images in images_list
|
||||
]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images_list = [
|
||||
[to_numpy_array(image) for image in images] for images in images_list
|
||||
]
|
||||
|
||||
if do_rescale and is_scaled_image(images_list[0][0]):
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||
|
||||
if do_pan_and_scan:
|
||||
images_list_and_num_crops = [
|
||||
self._process_images_for_pas(
|
||||
images=images,
|
||||
do_pan_and_scan=do_pan_and_scan,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for images in images_list
|
||||
]
|
||||
images_list = [images for images, _ in images_list_and_num_crops]
|
||||
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
|
||||
else:
|
||||
num_crops = [[0] for images in images_list]
|
||||
|
||||
if do_resize:
|
||||
height, width = size["height"], size["width"]
|
||||
images_list = [
|
||||
[
|
||||
resize(
|
||||
image=image,
|
||||
size=(height, width),
|
||||
resample=resample,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
for images in images_list
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images_list = [
|
||||
[
|
||||
self.rescale(
|
||||
image=image,
|
||||
scale=rescale_factor,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
for images in images_list
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images_list = [
|
||||
[
|
||||
self.normalize(
|
||||
image=image,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
for images in images_list
|
||||
]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(
|
||||
image, data_format, input_channel_dim=input_data_format
|
||||
)
|
||||
for images in images_list
|
||||
for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images, "num_crops": num_crops}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Gemma3ImageProcessor"]
|
@ -0,0 +1,204 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import (
|
||||
ImagesKwargs,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
Unpack,
|
||||
)
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import to_py_obj
|
||||
from text_generation_server.models.custom_modeling.gemma3.image_processing_gemma3 import (
|
||||
Gemma3ImageProcessor,
|
||||
)
|
||||
|
||||
from transformers.image_utils import PILImageResampling
|
||||
|
||||
from .utils import make_nested_list_of_images
|
||||
|
||||
|
||||
class Gemma3ImagesKwargs(ImagesKwargs):
|
||||
do_pan_and_scan: Optional[bool]
|
||||
pan_and_scan_min_crop_size: Optional[int]
|
||||
pan_and_scan_max_num_crops: Optional[int]
|
||||
pan_and_scan_min_ratio_to_activate: Optional[float]
|
||||
do_convert_rgb: Optional[bool]
|
||||
|
||||
|
||||
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {
|
||||
"do_pan_and_scan": False,
|
||||
"pan_and_scan_min_crop_size": 256,
|
||||
"pan_and_scan_max_num_crops": 4,
|
||||
"pan_and_scan_min_ratio_to_activate": 1.2,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Gemma3Processor(ProcessorMixin):
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
# # image_processor_class = "Gemma3ImageProcessor"
|
||||
image_processor_class = "AutoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor,
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
num_mm_soft_tokens_per_image: int = 256,
|
||||
**kwargs,
|
||||
):
|
||||
num_mm_soft_tokens_per_image = 256
|
||||
chat_template = None
|
||||
|
||||
image_processor = Gemma3ImageProcessor(
|
||||
image_mean=(127.5,) * 3,
|
||||
image_std=(127.5,) * 3,
|
||||
size={"height": 896, "width": 896},
|
||||
do_rescale=False,
|
||||
resample=PILImageResampling.BILINEAR,
|
||||
)
|
||||
|
||||
self.image_token_id = tokenizer.image_token_id
|
||||
image_tokens_expanded = "".join(
|
||||
[tokenizer.image_token] * num_mm_soft_tokens_per_image
|
||||
)
|
||||
self.full_image_sequence = (
|
||||
f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
||||
)
|
||||
|
||||
self.image_processor = image_processor
|
||||
self.tokenizer = tokenizer
|
||||
self.chat_template = chat_template
|
||||
|
||||
# super().__init__(
|
||||
# image_processor=image_processor,
|
||||
# tokenizer=tokenizer,
|
||||
# chat_template=chat_template,
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[
|
||||
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
|
||||
] = None,
|
||||
videos=None,
|
||||
audio=None,
|
||||
**kwargs: Unpack[Gemma3ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
if text is None and images is None:
|
||||
raise ValueError("Provide at least one of `text` or `images`.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Gemma3ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError(
|
||||
"Invalid input text. Please provide a string, or a list of strings"
|
||||
)
|
||||
|
||||
image_inputs = {}
|
||||
if images is not None:
|
||||
batched_images = make_nested_list_of_images(images)
|
||||
image_inputs = self.image_processor(
|
||||
batched_images, **output_kwargs["images_kwargs"]
|
||||
)
|
||||
|
||||
# Create empty text to be replaced with placeholders
|
||||
if not text:
|
||||
text = [
|
||||
" ".join(["<image>"] * len(images)) for images in batched_images
|
||||
]
|
||||
|
||||
if len(batched_images) != len(text):
|
||||
raise ValueError(
|
||||
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
||||
)
|
||||
|
||||
# Replace image tokens by the full expanded sequence
|
||||
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
||||
for prompt, images, num_crops in zip(text, batched_images, batch_num_crops):
|
||||
image_indexes = [m.start() for m in re.finditer("<image>", prompt)]
|
||||
|
||||
if len(images) != len(image_indexes):
|
||||
raise ValueError(
|
||||
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
|
||||
)
|
||||
|
||||
# Insert additional image tokens for Pan-and-Scan crops
|
||||
for num, idx in reversed(list(zip(num_crops, image_indexes))):
|
||||
if num:
|
||||
formatted_image_text = (
|
||||
"Here is the original image <image> and here are some crops to help you see better "
|
||||
+ " ".join(["<image>"] * num)
|
||||
)
|
||||
prompt = (
|
||||
prompt[:idx]
|
||||
+ formatted_image_text
|
||||
+ prompt[idx + len("<image>") :]
|
||||
)
|
||||
|
||||
# Expand placeholder image tokens to the full image token sequence
|
||||
text = [
|
||||
prompt.replace("<image>", self.full_image_sequence) for prompt in text
|
||||
]
|
||||
|
||||
text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_input, **image_inputs})
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
|
||||
__all__ = ["Gemma3Processor"]
|
@ -0,0 +1,61 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
from transformers.image_utils import ImageInput, is_valid_image, is_pil_image
|
||||
|
||||
|
||||
def is_valid_list_of_images(images: List):
|
||||
return images and all(is_valid_image(image) for image in images)
|
||||
|
||||
|
||||
def make_nested_list_of_images(
|
||||
images: Union[List[ImageInput], ImageInput],
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Ensure that the output is a nested list of images.
|
||||
Args:
|
||||
images (`Union[List[ImageInput], ImageInput]`):
|
||||
The input image.
|
||||
Returns:
|
||||
list: A list of list of images or a list of 4d array of images.
|
||||
"""
|
||||
# If it's a list of batches, it's already in the right format
|
||||
if (
|
||||
isinstance(images, (list, tuple))
|
||||
and all(isinstance(images_i, (list, tuple)) for images_i in images)
|
||||
and all(is_valid_list_of_images(images_i) for images_i in images)
|
||||
):
|
||||
return images
|
||||
|
||||
# If it's a list of images, it's a single batch, so convert it to a list of lists
|
||||
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
|
||||
if is_pil_image(images[0]) or images[0].ndim == 3:
|
||||
return [images]
|
||||
if images[0].ndim == 4:
|
||||
return [list(image) for image in images]
|
||||
|
||||
# If it's a single image, convert it to a list of lists
|
||||
if is_valid_image(images):
|
||||
if is_pil_image(images) or images.ndim == 3:
|
||||
return [[images]]
|
||||
if images.ndim == 4:
|
||||
return [list(images)]
|
||||
|
||||
raise ValueError(
|
||||
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
|
||||
)
|
@ -633,7 +633,7 @@ class Qwen2_5VisionModel(nn.Module):
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
self.temporal_patch_size = config.temporal_patch_size
|
||||
self.spatial_patch_size = config.spatial_patch_size
|
||||
self.in_channels = config.in_channels
|
||||
|
@ -542,6 +542,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
@ -23,6 +23,13 @@ def load_text_model(prefix, config, weights, name=None):
|
||||
)
|
||||
|
||||
return FlashGemma2ForCausalLM(prefix, config, weights)
|
||||
|
||||
elif config.model_type == "gemma3" or config.model_type == "gemma3_text":
|
||||
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
|
||||
FlashGemma3ForCausalLM,
|
||||
)
|
||||
|
||||
return FlashGemma3ForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "paligemma":
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
FlashGemmaForCausalLM,
|
||||
@ -42,13 +49,21 @@ def load_vision_model(prefix, config, weights):
|
||||
return CLIPVisionTransformer(
|
||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||
)
|
||||
if config.model_type == "siglip_vision_model":
|
||||
if (
|
||||
config.model_type == "siglip_vision_model"
|
||||
or config.model_type == "gemma3_vision"
|
||||
):
|
||||
from text_generation_server.models.custom_modeling.siglip import (
|
||||
SiglipVisionTransformer,
|
||||
)
|
||||
|
||||
# TODO: ensure that using the prefix doesn't break any existing models
|
||||
# that rely on the old prefix (update the old models if necessary)
|
||||
return SiglipVisionTransformer(
|
||||
prefix="vision_tower.vision_model", config=config, weights=weights
|
||||
# prefix="vision_model.vision_model", config=config, weights=weights
|
||||
prefix=f"{prefix}.vision_model",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
@ -83,24 +83,11 @@ from text_generation_server.models.metadata_kernels import (
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
|
||||
|
||||
def small_power_of_2(n: int):
|
||||
return 1 << ((n - 1).bit_length() - 1)
|
||||
|
||||
|
||||
def set_sliding_window(sliding_window: int):
|
||||
global SLIDING_WINDOW
|
||||
SLIDING_WINDOW = sliding_window
|
||||
|
||||
|
||||
def get_sliding_windows() -> int:
|
||||
global SLIDING_WINDOW
|
||||
return SLIDING_WINDOW
|
||||
|
||||
|
||||
def init_cpu_threads_env(rank_id: int, world_size: int):
|
||||
import importlib.util
|
||||
|
||||
@ -1002,10 +989,8 @@ class FlashCausalLMBatch(Batch):
|
||||
self.slot_indices,
|
||||
)
|
||||
|
||||
sliding_window = get_sliding_windows()
|
||||
position_ids = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
all_prefill_logprobs = True
|
||||
no_prefill_logprobs = True
|
||||
prefill_cu_outlens = [0]
|
||||
@ -1064,14 +1049,6 @@ class FlashCausalLMBatch(Batch):
|
||||
# Update
|
||||
cumulative_slot_tokens += len(request_slots)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
if sliding_window is not None:
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - sliding_window),
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
# Prefill logprobs is ignored if the request is done prefilling
|
||||
prefill_logprobs = r.prefill_logprobs and request_prefilling
|
||||
|
||||
@ -1085,9 +1062,6 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
ADAPTER_TO_INDEX = get_adapter_to_index()
|
||||
if ADAPTER_TO_INDEX:
|
||||
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||
@ -1151,24 +1125,18 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids = torch.cat(position_ids)
|
||||
if slot_indices:
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
if position_ids:
|
||||
position_ids = position_ids[0]
|
||||
if slot_indices:
|
||||
slot_indices = slot_indices[0]
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
if not has_triton():
|
||||
self.position_ids = position_ids.to(device)
|
||||
self.slot_indices = slot_indices.to(device)
|
||||
|
||||
self.prefill_cu_outlens = prefill_cu_outlens
|
||||
self.prefill_cache_indices = (
|
||||
prefill_cache_indices.to(device) if sliding_window is not None else None
|
||||
)
|
||||
self.prefill_cache_indices = None
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
@ -1306,9 +1274,7 @@ class FlashCausalLM(Model):
|
||||
if text_config is not None:
|
||||
config = text_config
|
||||
|
||||
if getattr(config, "sliding_window", None) is not None:
|
||||
set_sliding_window(config.sliding_window)
|
||||
else:
|
||||
if getattr(config, "sliding_window", None) is None:
|
||||
config.sliding_window = None
|
||||
|
||||
self.num_layers = config.num_hidden_layers
|
||||
@ -2500,7 +2466,6 @@ class FlashCausalLM(Model):
|
||||
page_size=BLOCK_SIZE,
|
||||
kv_dtype=self.kv_cache_dtype,
|
||||
q_dtype=self.dtype,
|
||||
window_left=self.sliding_window,
|
||||
)
|
||||
else:
|
||||
assert input_lengths_tensor is not None
|
||||
@ -2514,5 +2479,4 @@ class FlashCausalLM(Model):
|
||||
page_size=BLOCK_SIZE,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
q_dtype=self.dtype,
|
||||
window_left=self.sliding_window,
|
||||
)
|
||||
|
@ -110,7 +110,7 @@ class Model(ABC):
|
||||
requires_padding=self.requires_padding,
|
||||
dtype=str(self.dtype),
|
||||
device_type=self.device.type,
|
||||
window_size=self.sliding_window,
|
||||
window_size=None, # Setting this parameter to None disabled the block logic with sliding window.
|
||||
speculate=self.speculate,
|
||||
support_chunking=self.support_chunking,
|
||||
use_prefix_caching=PREFIX_CACHING,
|
||||
|
@ -128,6 +128,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
num_pads = grid_t * grid_h * grid_w // 4
|
||||
padding = "<|image_pad|>" * num_pads
|
||||
return f"<|vision_start|>{padding}<|vision_end|>"
|
||||
elif config.model_type == "gemma3":
|
||||
# TODO: get correct number of features via reviewing the Gemma3 architecture
|
||||
# and calculating the number of image tokens
|
||||
num_pads = 256
|
||||
padding = "<image_soft_token>" * num_pads
|
||||
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
|
||||
else:
|
||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||
|
||||
@ -244,6 +250,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
if config.model_type == "llava_next":
|
||||
images.append(image)
|
||||
elif config.model_type == "gemma3":
|
||||
images.append(image)
|
||||
else:
|
||||
images.append([image])
|
||||
else:
|
||||
|
@ -18,14 +18,10 @@ def get_cuda_free_memory(device, memory_fraction):
|
||||
|
||||
|
||||
def get_xpu_free_memory(device, memory_fraction):
|
||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
||||
device_id = device.index
|
||||
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
|
||||
total_free_memory, total_xpu_memory = torch.xpu.mem_get_info(device)
|
||||
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "0.9"))
|
||||
free_memory = max(
|
||||
0,
|
||||
int(
|
||||
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
|
||||
),
|
||||
0, int(total_free_memory - (1 - memory_fraction) * total_xpu_memory)
|
||||
)
|
||||
return free_memory
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from loguru import logger
|
||||
from hf_kernels import load_kernel as hf_load_kernel
|
||||
from kernels import load_kernel as hf_load_kernel
|
||||
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
@ -79,6 +79,8 @@ def _get_quantizer_config(model_id, revision):
|
||||
modules_to_not_convert = data["quantization_config"].get(
|
||||
"modules_to_not_convert", []
|
||||
)
|
||||
if modules_to_not_convert is None:
|
||||
modules_to_not_convert = []
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user