diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 6d867ebb7..99f29d7eb 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -6,10 +6,11 @@ on:
hardware:
type: string
description: Hardware
- # options:
- # - cuda
- # - rocm
- # - intel
+ # options:
+ # - cuda
+ # - cuda-trtllm
+ # - rocm
+ # - intel-xpu
+ # - intel-cpu
+ # - neuron
+ # - gaudi
required: true
release-tests:
description: "Run release integration tests"
@@ -24,22 +25,34 @@ jobs:
docker_volume: ${{ steps.final.outputs.docker_volume }}
docker_devices: ${{ steps.final.outputs.docker_devices }}
runs_on: ${{ steps.final.outputs.runs_on }}
- label: ${{ steps.final.outputs.label }}
+ label_extension: ${{ steps.final.outputs.label_extension }}
extra_pytest: ${{ steps.final.outputs.extra_pytest }}
concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
- group: aws-highmemory-32-plus-priv
+ group: aws-highmemory-64-plus-priv
permissions:
contents: write
packages: write
+ id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- - name: Construct harware variables
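+ # sccache's GitHub Actions cache backend needs these two values; they are forwarded to the Docker build below as the actions_cache_url and actions_runtime_token build args.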
+ - name: Inject required variables for sccache to interact with GitHub Actions Cache
+ uses: actions/github-script@v7
+ with:
+ script: |
+ core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
+ core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+
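+ # backends/trtllm/cmake/trtllm.cmake pins TensorRT-LLM to a specific commit (a 40-character hash); expose it to later steps via GITHUB_ENV.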
+ - name: Extract TensorRT-LLM version
+ run: |
+ echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
+ echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
+ - name: Construct hardware variables
shell: bash
run: |
case ${{ inputs.hardware }} in
@@ -51,15 +64,34 @@ jobs:
export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
export extra_pytest=""
+ export target=""
+ ;;
+ cuda-trtllm)
+ export dockerfile="Dockerfile_trtllm"
+ export label_extension="-trtllm"
+ export docker_volume="/mnt/cache"
+ export docker_devices=""
+ export runs_on="ubuntu-latest"
+ export platform=""
+ export extra_pytest=""
+ if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
+ export build_type="release";
+ export target="";
+ else
+ export build_type="dev";
+ export target="ci-runtime";
+ fi
;;
rocm)
export dockerfile="Dockerfile_amd"
export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri"
export docker_volume="/mnt"
- export runs_on="amd-gpu-runners"
+ # This runner was deactivated.
+ export runs_on="ubuntu-latest"
export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load"
+ export target=""
;;
intel-xpu)
export dockerfile="Dockerfile_intel"
@@ -69,6 +101,7 @@ jobs:
export runs_on="ubuntu-latest"
export platform="xpu"
export extra_pytest=""
+ export target=""
;;
intel-cpu)
export dockerfile="Dockerfile_intel"
@@ -79,7 +112,27 @@ jobs:
export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple"
+ export target=""
;;
+ neuron)
+ export dockerfile="Dockerfile.neuron"
+ export label_extension="-neuron"
+ export docker_devices="/dev/neuron0"
+ export docker_volume="/mnt/cache"
+ export runs_on="aws-inf2-8xlarge"
+ export platform="cpu"
+ export extra_pytest="--neuron"
+ export target=""
+ ;;
+ gaudi)
+ export dockerfile="Dockerfile_gaudi"
+ export label_extension="-gaudi"
+ export docker_volume="/mnt/cache"
+ export docker_devices=""
+ export runs_on="ubuntu-latest"
+ export platform=""
+ export extra_pytest=""
+ export target=""
esac
echo $dockerfile
echo "Dockerfile=${dockerfile}"
@@ -88,19 +141,22 @@ jobs:
echo $runs_on
echo $platform
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
- echo "LABEL=${label_extension}" >> $GITHUB_ENV
+ echo "LABEL_EXTENSION=${label_extension}" >> $GITHUB_ENV
echo "PLATFORM=${platform}" >> $GITHUB_ENV
echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
+ echo "TARGET=${target}" >> $GITHUB_ENV
+ echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Login to internal Container Registry
+ if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
@@ -113,6 +169,12 @@ jobs:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Login to Docker Hub Container Registry
+ uses: docker/login-action@v3
+ with:
+ registry: docker.io
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Login to Azure Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
@@ -127,9 +189,9 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
- registry.internal.huggingface.tech/api-inference/community/text-generation-inference
+ docker.io/huggingface/text-generation-inference-ci
tags: |
- type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+ type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
# If main, release or tag
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }}
@@ -143,10 +205,10 @@ jobs:
ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: |
- type=semver,pattern={{version}}${{ env.LABEL }}
- type=semver,pattern={{major}}.{{minor}}${{ env.LABEL }}
- type=raw,value=latest${{ env.LABEL }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
- type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+ type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}
+ type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}
+ type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
+ type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
- name: Build and push Docker image
id: build-and-push
uses: docker/build-push-action@v4
@@ -157,27 +219,66 @@ jobs:
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}
- DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+ DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
PLATFORM=${{ env.PLATFORM }}
+ build_type=${{ env.BUILD_TYPE }}
+ sccache_gha_enabled=on
+ actions_cache_url=${{ env.ACTIONS_CACHE_URL }}
+ actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
+ target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
- cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
- cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+ cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }}
+ cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }}
- name: Final
id: final
run: |
- echo "docker_image=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+
+ if [ "${{ github.event_name }}" = "pull_request" ]; then
+ echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+ else
+ echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
+ fi
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
- echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+ echo "label_extension=${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
- integration_tests:
+ precompile_neuron_models:
concurrency:
- group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
+ group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: build-and-push
- if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
+ if: needs.build-and-push.outputs.label_extension == '-neuron'
+ runs-on:
+ group: ${{ needs.build-and-push.outputs.runs_on }}
+ env:
+ PYTEST_FLAGS: '--release'
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - name: Inject slug/short variables
+ uses: rlespinasse/github-slug-action@v4.4.1
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.11"
+ - name: Install
+ run: |
+ make install-integration-tests
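+ # Export/compile the Neuron models ahead of time so the integration test job that follows can reuse them.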
+ - name: Export neuron models
+ run: |
+ export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
+ echo $DOCKER_IMAGE
+ docker pull $DOCKER_IMAGE
+ export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}
+ python integration-tests/fixtures/neuron/export_models.py
+ integration_tests:
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+ needs: [precompile_neuron_models, build-and-push]
+ if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
env:
@@ -204,3 +305,23 @@ jobs:
echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
+
+ backend_trtllm_cxx_tests:
+ needs: build-and-push
+ if: needs.build-and-push.outputs.label_extension == '-trtllm'
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+ runs-on:
+ group: aws-g6-12xl-plus-priv-cache
+ container:
+ image: ${{ needs.build-and-push.outputs.docker_image }}
+ credentials:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
+ options: --gpus all --shm-size=8g
+
+ steps:
+ - name: Run C++/CUDA tests
+ if: ${{ !startsWith(github.ref, 'refs/tags/') }}
+ run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml
index 5190f3217..f0d39399b 100644
--- a/.github/workflows/ci_build.yaml
+++ b/.github/workflows/ci_build.yaml
@@ -20,6 +20,8 @@ on:
- "Dockerfile"
- "Dockerfile_amd"
- "Dockerfile_intel"
+ - "Dockerfile.neuron"
+ - "Dockerfile_gaudi"
branches:
- "main"
workflow_dispatch:
@@ -37,11 +39,12 @@ jobs:
# fail-fast is true by default
fail-fast: false
matrix:
- hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
+ hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write
packages: write
+ id-token: write
with:
hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206
diff --git a/.github/workflows/nix_build.yaml b/.github/workflows/nix_build.yaml
new file mode 100644
index 000000000..71ad59d0b
--- /dev/null
+++ b/.github/workflows/nix_build.yaml
@@ -0,0 +1,53 @@
+name: "Nix Build Docker image"
+on:
+ pull_request:
+ push:
+ branches:
+ - 'main'
+ tags:
+ - 'v*'
+concurrency:
+ group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ build_nix_image:
+ runs-on:
+ group: aws-highmemory-32-plus-priv
+ steps:
+ - uses: actions/checkout@v4
+ - uses: cachix/install-nix-action@v27
+ with:
+ nix_path: nixpkgs=channel:nixos-unstable
+ - uses: cachix/cachix-action@v14
+ with:
+ name: text-generation-inference
+ # If you chose a signing key for write access
+ authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
+ env:
+ USER: github_runner
+ - name: Build
+ run: nix build .#dockerImage
+ - name: Initialize Docker Buildx
+ uses: docker/setup-buildx-action@v3
+ with:
+ install: true
+ buildkitd-config: /tmp/buildkitd.toml
+ - name: Inject slug/short variables
+ uses: rlespinasse/github-slug-action@v4.4.1
+ - name: Login to internal Container Registry
+ # if: github.event_name != 'pull_request'
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.REGISTRY_USERNAME }}
+ password: ${{ secrets.REGISTRY_PASSWORD }}
+ registry: registry.internal.huggingface.tech
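+ # `nix build` writes a docker-archive to ./result; skopeo copies it to the internal registry, compressing layers with zstd.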
+ - name: Push to docker
+ run: |
+ if [ "${{ github.event_name }}" = "pull_request" ]; then
+ export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
+ else
+ export TAG=${{ github.ref_name }}-nix
+ fi
+ export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
+ nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
diff --git a/.github/workflows/nix_tests.yaml b/.github/workflows/nix_tests.yaml
index f2209f8a4..d9b910483 100644
--- a/.github/workflows/nix_tests.yaml
+++ b/.github/workflows/nix_tests.yaml
@@ -7,6 +7,7 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
+ - "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 4eeca3348..3e431c861 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -8,6 +8,7 @@ on:
- "proto/**"
- "router/**"
- "launcher/**"
+ - "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
@@ -20,19 +21,14 @@ jobs:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
id: python
with:
python-version: 3.11
- - name: Install Rust
- uses: actions-rs/toolchain@v1
+ - uses: dtolnay/rust-toolchain@1.85.0
with:
- # Released on: 02 May, 2024
- # https://releases.rs/docs/1.78.0/
- toolchain: 1.80.0
- override: true
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
@@ -44,10 +40,18 @@ jobs:
run: |
sudo apt update
sudo apt install python3.11-dev -y
+ pip install -U pip uv
+ uv venv
+ source ./.venv/bin/activate
make install-cpu
+ - name: Download locked kernels
+ run: |
+ source ./.venv/bin/activate
+ kernels download server
- name: Run server tests
run: |
- pip install pytest
+ source ./.venv/bin/activate
+ uv pip install pytest
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests
- name: Pre-commit checks
diff --git a/.github/workflows/trufflehog.yaml b/.github/workflows/trufflehog.yaml
index b406d43b8..9f1c5f36d 100644
--- a/.github/workflows/trufflehog.yaml
+++ b/.github/workflows/trufflehog.yaml
@@ -10,9 +10,12 @@ jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- - name: Secret Scanning
- uses: trufflesecurity/trufflehog@main
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Secret Scanning
+ uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
+ with:
+ # exclude the buggy postgres detector, which causes false positives and is not relevant to our codebase
+ extra_args: --results=verified,unknown --exclude-detectors=postgres
diff --git a/.gitignore b/.gitignore
index 9434d75ca..8a6bda722 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,9 @@ server/fbgemmm
.direnv/
.venv/
+
+# Gaudi auto-generated files
+hl-smi_log*.txt
+.graph_dumps
+out
+hqt_output
diff --git a/Cargo.lock b/Cargo.lock
index 9551ae2d9..3cac30fb6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
-version = 3
+version = 4
[[package]]
name = "addr2line"
@@ -24,11 +24,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"once_cell",
"serde",
"version_check",
- "zerocopy",
+ "zerocopy 0.7.35",
]
[[package]]
@@ -48,9 +48,24 @@ checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
[[package]]
name = "allocator-api2"
-version = "0.2.20"
+version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
+checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
+
+[[package]]
+name = "android-tzdata"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
+
+[[package]]
+name = "android_system_properties"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
+dependencies = [
+ "libc",
+]
[[package]]
name = "anstream"
@@ -93,19 +108,20 @@ dependencies = [
[[package]]
name = "anstyle-wincon"
-version = "3.0.6"
+version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125"
+checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
+ "once_cell",
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
-version = "1.0.93"
+version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
+checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
[[package]]
name = "arbitrary"
@@ -127,7 +143,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -166,18 +182,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "async-trait"
-version = "0.1.83"
+version = "0.1.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
+checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -230,9 +246,9 @@ dependencies = [
[[package]]
name = "avif-serialize"
-version = "0.8.2"
+version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62"
+checksum = "98922d6a4cfbcb08820c69d8eeccc05bb1f29bfa06b4f5b1dbfe9a868bd7608e"
dependencies = [
"arrayvec",
]
@@ -251,29 +267,25 @@ dependencies = [
[[package]]
name = "aws-lc-rs"
-version = "1.11.0"
+version = "1.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fe7c2840b66236045acd2607d5866e274380afd87ef99d6226e961e2cb47df45"
+checksum = "dabb68eb3a7aa08b46fddfd59a3d55c978243557a90ab804769f7e20e67d2b01"
dependencies = [
"aws-lc-sys",
- "mirai-annotations",
- "paste",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
-version = "0.23.0"
+version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad3a619a9de81e1d7de1f1186dcba4506ed661a0e483d84410fdef0ee87b2f96"
+checksum = "77926887776171ced7d662120a75998e444d3750c951abfe07f90da130514b1f"
dependencies = [
- "bindgen",
+ "bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
- "libc",
- "paste",
]
[[package]]
@@ -289,7 +301,7 @@ dependencies = [
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"itoa",
"matchit",
"memchr",
@@ -318,10 +330,10 @@ dependencies = [
"axum-core 0.4.5",
"bytes",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
"http-body-util",
- "hyper 1.5.1",
+ "hyper 1.6.0",
"hyper-util",
"itoa",
"matchit",
@@ -336,7 +348,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
- "tower 0.5.1",
+ "tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
@@ -368,7 +380,7 @@ dependencies = [
"async-trait",
"bytes",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
@@ -389,7 +401,7 @@ dependencies = [
"axum 0.7.9",
"futures-core",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"opentelemetry 0.21.0",
"pin-project-lite",
"tower 0.4.13",
@@ -437,7 +449,7 @@ version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
@@ -448,26 +460,46 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
- "rustc-hash",
+ "rustc-hash 1.1.0",
"shlex",
- "syn 2.0.89",
+ "syn 2.0.100",
"which",
]
[[package]]
-name = "bit-set"
-version = "0.5.3"
+name = "bindgen"
+version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
+dependencies = [
+ "bitflags 2.9.0",
+ "cexpr",
+ "clang-sys",
+ "itertools 0.13.0",
+ "log",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash 2.1.1",
+ "shlex",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "bit-set"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
-version = "0.6.3"
+version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
[[package]]
name = "bit_field"
@@ -483,9 +515,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
-version = "2.6.0"
+version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
+checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd"
[[package]]
name = "bitstream-io"
@@ -503,16 +535,22 @@ dependencies = [
]
[[package]]
-name = "built"
-version = "0.7.5"
+name = "borrow-or-share"
+version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b"
+checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"
+
+[[package]]
+name = "built"
+version = "0.7.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b"
[[package]]
name = "bumpalo"
-version = "3.16.0"
+version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
+checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
[[package]]
name = "bytecount"
@@ -522,9 +560,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytemuck"
-version = "1.20.0"
+version = "1.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a"
+checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540"
[[package]]
name = "byteorder"
@@ -540,9 +578,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
-version = "1.8.0"
+version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
+checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]]
name = "camino"
@@ -555,9 +593,9 @@ dependencies = [
[[package]]
name = "cargo-platform"
-version = "0.1.8"
+version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc"
+checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea"
dependencies = [
"serde",
]
@@ -573,7 +611,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
]
[[package]]
@@ -599,9 +637,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.2.1"
+version = "1.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
+checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a"
dependencies = [
"jobserver",
"libc",
@@ -645,6 +683,20 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
+[[package]]
+name = "chrono"
+version = "0.4.40"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
+dependencies = [
+ "android-tzdata",
+ "iana-time-zone",
+ "js-sys",
+ "num-traits",
+ "wasm-bindgen",
+ "windows-link",
+]
+
[[package]]
name = "clang-sys"
version = "1.8.1"
@@ -669,9 +721,9 @@ dependencies = [
[[package]]
name = "clap"
-version = "4.5.21"
+version = "4.5.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f"
+checksum = "6088f3ae8c3608d19260cd7445411865a485688711b78b5be70d78cd96136f83"
dependencies = [
"clap_builder",
"clap_derive",
@@ -679,9 +731,9 @@ dependencies = [
[[package]]
name = "clap_builder"
-version = "4.5.21"
+version = "4.5.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec"
+checksum = "22a7ef7f676155edfb82daa97f99441f3ebf4a58d5e32f295a56259f1b6facc8"
dependencies = [
"anstream",
"anstyle",
@@ -691,39 +743,40 @@ dependencies = [
[[package]]
name = "clap_derive"
-version = "4.5.18"
+version = "4.5.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
+checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "clap_lex"
-version = "0.7.3"
+version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
+checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
-version = "0.1.51"
+version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
+checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "codespan-reporting"
-version = "0.11.1"
+version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
+checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
+ "serde",
"termcolor",
- "unicode-width 0.1.14",
+ "unicode-width 0.2.0",
]
[[package]]
@@ -740,9 +793,9 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "compact_str"
-version = "0.8.0"
+version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644"
+checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32"
dependencies = [
"castaway",
"cfg-if",
@@ -754,15 +807,15 @@ dependencies = [
[[package]]
name = "console"
-version = "0.15.8"
+version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb"
+checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
dependencies = [
"encode_unicode",
- "lazy_static",
"libc",
- "unicode-width 0.1.14",
- "windows-sys 0.52.0",
+ "once_cell",
+ "unicode-width 0.2.0",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -793,9 +846,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "cpufeatures"
-version = "0.2.16"
+version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3"
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
@@ -847,18 +900,18 @@ dependencies = [
[[package]]
name = "crossbeam-channel"
-version = "0.5.13"
+version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2"
+checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
-version = "0.8.5"
+version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
@@ -875,9 +928,9 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
-version = "0.8.20"
+version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
+checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crossterm"
@@ -885,11 +938,11 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"crossterm_winapi",
"mio",
"parking_lot",
- "rustix",
+ "rustix 0.38.44",
"signal-hook",
"signal-hook-mio",
"winapi",
@@ -906,9 +959,9 @@ dependencies = [
[[package]]
name = "crunchy"
-version = "0.2.2"
+version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
+checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "crypto-common"
@@ -934,9 +987,9 @@ dependencies = [
[[package]]
name = "csv-core"
-version = "0.1.11"
+version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70"
+checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
dependencies = [
"memchr",
]
@@ -953,46 +1006,61 @@ dependencies = [
[[package]]
name = "cxx"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23c042a0ba58aaff55299632834d1ea53ceff73d62373f62c9ae60890ad1b942"
+checksum = "6d1cf22155cf6a8e0b0536efc30c775eadd7a481c376d2d7e30daf0825a42ef9"
dependencies = [
"cc",
+ "cxxbridge-cmd",
"cxxbridge-flags",
"cxxbridge-macro",
+ "foldhash",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45dc1c88d0fdac57518a9b1f6c4f4fb2aca8f3c30c0d03d7d8518b47ca0bcea6"
+checksum = "db4e07e3a69db032f03450594e53785a5d6b1d787c2ad5b901d9347f0064af94"
dependencies = [
"cc",
"codespan-reporting",
"proc-macro2",
"quote",
"scratch",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "cxxbridge-cmd"
+version = "1.0.150"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48e9ff9c627d3abe06190462f7db81fb6cc12f3424ea081c2a8c9ed7a8cc167a"
+dependencies = [
+ "clap 4.5.32",
+ "codespan-reporting",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
]
[[package]]
name = "cxxbridge-flags"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aa7ed7d30b289e2592cc55bc2ccd89803a63c913e008e6eb59f06cddf45bb52f"
+checksum = "2e6417f4e1518ded330e088d5a66f50fbae9bbc96840e147058ae44970a2b51a"
[[package]]
name = "cxxbridge-macro"
-version = "1.0.130"
+version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0b8c465d22de46b851c04630a5fc749a26005b263632ed2e0d9cc81518ead78d"
+checksum = "856ff0dba6e023dd78189c8f4667126842dfe88392b5d4e94118bd18b8f2afbf"
dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1016,7 +1084,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1027,14 +1095,14 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "deranged"
-version = "0.3.11"
+version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
+checksum = "28cfac68e08048ae1883171632c2aef3ebc555621ae56fbccce1cbf22dd7f058"
dependencies = [
"powerfmt",
]
@@ -1057,7 +1125,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1067,15 +1135,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
- "syn 2.0.89",
+ "syn 2.0.100",
]
-[[package]]
-name = "diff"
-version = "0.1.13"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
-
[[package]]
name = "digest"
version = "0.10.7"
@@ -1115,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1126,24 +1188,33 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "easy-cast"
-version = "0.5.2"
+version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6"
+checksum = "72852736692ec862655eca398c9bb1b476161b563c9f80f45f4808b9629750d6"
dependencies = [
"libm",
]
[[package]]
name = "either"
-version = "1.13.0"
+version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
+[[package]]
+name = "email_address"
+version = "0.2.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
+dependencies = [
+ "serde",
+]
[[package]]
name = "encode_unicode"
-version = "0.3.6"
+version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
+checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
@@ -1156,18 +1227,18 @@ dependencies = [
[[package]]
name = "equivalent"
-version = "1.0.1"
+version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
+checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
-version = "0.3.9"
+version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
+checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
dependencies = [
"libc",
- "windows-sys 0.52.0",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -1186,7 +1257,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
dependencies = [
"bit_field",
- "half 2.4.1",
+ "half 2.5.0",
"lebe",
"miniz_oxide",
"rayon-core",
@@ -1196,25 +1267,26 @@ dependencies = [
[[package]]
name = "fancy-regex"
-version = "0.11.0"
+version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
+checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298"
dependencies = [
"bit-set",
- "regex",
+ "regex-automata 0.4.9",
+ "regex-syntax 0.8.5",
]
[[package]]
name = "fastrand"
-version = "2.2.0"
+version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fdeflate"
-version = "0.3.6"
+version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb"
+checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c"
dependencies = [
"simd-adler32",
]
@@ -1227,9 +1299,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
-version = "1.0.35"
+version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
+checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
dependencies = [
"crc32fast",
"miniz_oxide",
@@ -1247,6 +1319,17 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
+[[package]]
+name = "fluent-uri"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5"
+dependencies = [
+ "borrow-or-share",
+ "ref-cast",
+ "serde",
+]
+
[[package]]
name = "fnv"
version = "1.0.7"
@@ -1255,9 +1338,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
-version = "0.1.3"
+version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
@@ -1285,9 +1368,9 @@ dependencies = [
[[package]]
name = "fraction"
-version = "0.13.1"
+version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678"
+checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7"
dependencies = [
"lazy_static",
"num",
@@ -1355,7 +1438,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1414,10 +1497,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
- "js-sys",
"libc",
- "wasi",
- "wasm-bindgen",
+ "wasi 0.11.0+wasi-snapshot-preview1",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "r-efi",
+ "wasi 0.14.2+wasi-0.2.4",
]
[[package]]
@@ -1438,9 +1531,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
+checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "grpc-metadata"
@@ -1464,7 +1557,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http 0.2.12",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"slab",
"tokio",
"tokio-util",
@@ -1473,17 +1566,17 @@ dependencies = [
[[package]]
name = "h2"
-version = "0.4.7"
+version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e"
+checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
- "http 1.1.0",
- "indexmap 2.6.0",
+ "http 1.3.1",
+ "indexmap 2.8.0",
"slab",
"tokio",
"tokio-util",
@@ -1498,9 +1591,9 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "half"
-version = "2.4.1"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
+checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1"
dependencies = [
"cfg-if",
"crunchy",
@@ -1519,14 +1612,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
- "allocator-api2",
]
[[package]]
name = "hashbrown"
-version = "0.15.1"
+version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3"
+checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
@@ -1567,27 +1659,47 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
dependencies = [
"dirs",
- "futures",
"indicatif",
"log",
"native-tls",
- "num_cpus",
- "rand",
- "reqwest",
+ "rand 0.8.5",
"serde",
"serde_json",
- "thiserror",
- "tokio",
+ "thiserror 1.0.69",
"ureq",
]
[[package]]
-name = "home"
-version = "0.5.9"
+name = "hf-hub"
+version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
+checksum = "cc03dcb0b0a83ae3f3363ec811014ae669f083e4e499c66602f447c4828737a1"
dependencies = [
- "windows-sys 0.52.0",
+ "dirs",
+ "futures",
+ "http 1.3.1",
+ "indicatif",
+ "libc",
+ "log",
+ "native-tls",
+ "num_cpus",
+ "rand 0.8.5",
+ "reqwest 0.12.15",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.12",
+ "tokio",
+ "ureq",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
+name = "home"
+version = "0.5.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
+dependencies = [
+ "windows-sys 0.59.0",
]
[[package]]
@@ -1614,9 +1726,9 @@ dependencies = [
[[package]]
name = "http"
-version = "1.1.0"
+version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258"
+checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
@@ -1641,27 +1753,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
- "http 1.1.0",
+ "http 1.3.1",
]
[[package]]
name = "http-body-util"
-version = "0.1.2"
+version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
+checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
- "futures-util",
- "http 1.1.0",
+ "futures-core",
+ "http 1.3.1",
"http-body 1.0.1",
"pin-project-lite",
]
[[package]]
name = "httparse"
-version = "1.9.5"
+version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946"
+checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
@@ -1671,9 +1783,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
-version = "0.14.31"
+version = "0.14.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85"
+checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
dependencies = [
"bytes",
"futures-channel",
@@ -1695,15 +1807,15 @@ dependencies = [
[[package]]
name = "hyper"
-version = "1.5.1"
+version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f"
+checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
- "h2 0.4.7",
- "http 1.1.0",
+ "h2 0.4.8",
+ "http 1.3.1",
"http-body 1.0.1",
"httparse",
"httpdate",
@@ -1716,16 +1828,16 @@ dependencies = [
[[package]]
name = "hyper-rustls"
-version = "0.27.3"
+version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333"
+checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
- "http 1.1.0",
- "hyper 1.5.1",
+ "http 1.3.1",
+ "hyper 1.6.0",
"hyper-util",
"log",
- "rustls 0.23.17",
+ "rustls 0.23.25",
"rustls-native-certs",
"rustls-pki-types",
"tokio",
@@ -1739,7 +1851,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
- "hyper 0.14.31",
+ "hyper 0.14.32",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
@@ -1752,12 +1864,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"native-tls",
"tokio",
"tokio-native-tls",
]
+[[package]]
+name = "hyper-tls"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
+dependencies = [
+ "bytes",
+ "http-body-util",
+ "hyper 1.6.0",
+ "hyper-util",
+ "native-tls",
+ "tokio",
+ "tokio-native-tls",
+ "tower-service",
+]
+
[[package]]
name = "hyper-util"
version = "0.1.10"
@@ -1767,9 +1895,9 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
- "hyper 1.5.1",
+ "hyper 1.6.0",
"pin-project-lite",
"socket2",
"tokio",
@@ -1777,6 +1905,29 @@ dependencies = [
"tracing",
]
+[[package]]
+name = "iana-time-zone"
+version = "0.1.61"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
+dependencies = [
+ "android_system_properties",
+ "core-foundation-sys",
+ "iana-time-zone-haiku",
+ "js-sys",
+ "wasm-bindgen",
+ "windows-core",
+]
+
+[[package]]
+name = "iana-time-zone-haiku"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
+dependencies = [
+ "cc",
+]
+
[[package]]
name = "icu_collections"
version = "1.5.0"
@@ -1892,7 +2043,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -1947,9 +2098,9 @@ dependencies = [
[[package]]
name = "image-webp"
-version = "0.2.0"
+version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f"
+checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f"
dependencies = [
"byteorder-lite",
"quick-error",
@@ -1973,20 +2124,20 @@ dependencies = [
[[package]]
name = "indexmap"
-version = "2.6.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
+checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
dependencies = [
"equivalent",
- "hashbrown 0.15.1",
+ "hashbrown 0.15.2",
"serde",
]
[[package]]
name = "indicatif"
-version = "0.17.9"
+version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281"
+checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
dependencies = [
"console",
"number_prefix",
@@ -1997,9 +2148,9 @@ dependencies = [
[[package]]
name = "indoc"
-version = "2.0.5"
+version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
+checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
[[package]]
name = "init-tracing-opentelemetry"
@@ -2009,23 +2160,22 @@ checksum = "94bd26b1b737bc11f183620072e188d1c6ede67e0e78682228d66b49ec510e17"
dependencies = [
"opentelemetry 0.20.0",
"opentelemetry-otlp",
- "thiserror",
+ "thiserror 1.0.69",
"tracing",
"tracing-opentelemetry 0.21.0",
]
[[package]]
name = "instability"
-version = "0.3.3"
+version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e"
+checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d"
dependencies = [
"darling",
"indoc",
- "pretty_assertions",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -2036,14 +2186,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "ipnet"
-version = "2.10.1"
+version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
+checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
@@ -2051,15 +2201,6 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
-[[package]]
-name = "iso8601"
-version = "0.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
-dependencies = [
- "nom",
-]
-
[[package]]
name = "itertools"
version = "0.10.5"
@@ -2098,9 +2239,9 @@ dependencies = [
[[package]]
name = "itoa"
-version = "1.0.13"
+version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2"
+checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jobserver"
@@ -2119,41 +2260,37 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
[[package]]
name = "js-sys"
-version = "0.3.72"
+version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9"
+checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
dependencies = [
+ "once_cell",
"wasm-bindgen",
]
[[package]]
name = "jsonschema"
-version = "0.17.1"
+version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978"
+checksum = "4b8f66fe41fa46a5c83ed1c717b7e0b4635988f427083108c8cf0a882cc13441"
dependencies = [
"ahash",
- "anyhow",
- "base64 0.21.7",
+ "base64 0.22.1",
"bytecount",
- "clap 4.5.21",
+ "email_address",
"fancy-regex",
"fraction",
- "getrandom",
- "iso8601",
+ "idna",
"itoa",
- "memchr",
"num-cmp",
"once_cell",
- "parking_lot",
"percent-encoding",
- "regex",
- "reqwest",
+ "referencing",
+ "regex-syntax 0.8.5",
+ "reqwest 0.12.15",
"serde",
"serde_json",
- "time",
- "url",
- "uuid",
+ "uuid-simd",
]
[[package]]
@@ -2176,15 +2313,15 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
-version = "0.2.164"
+version = "0.2.171"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
+checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6"
[[package]]
name = "libfuzzer-sys"
-version = "0.4.8"
+version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa"
+checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75"
dependencies = [
"arbitrary",
"cc",
@@ -2192,9 +2329,9 @@ dependencies = [
[[package]]
name = "libloading"
-version = "0.8.5"
+version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
+checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
@@ -2212,30 +2349,36 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"libc",
]
[[package]]
name = "link-cplusplus"
-version = "1.0.9"
+version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9"
+checksum = "4a6f6da007f968f9def0d65a05b187e2960183de70c160204ecfccf0ee330212"
dependencies = [
"cc",
]
[[package]]
name = "linux-raw-sys"
-version = "0.4.14"
+version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
+checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
+
+[[package]]
+name = "linux-raw-sys"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413"
[[package]]
name = "litemap"
-version = "0.7.3"
+version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704"
+checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
[[package]]
name = "lock_api"
@@ -2249,9 +2392,9 @@ dependencies = [
[[package]]
name = "log"
-version = "0.4.22"
+version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
+checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]]
name = "loop9"
@@ -2268,7 +2411,7 @@ version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [
- "hashbrown 0.15.1",
+ "hashbrown 0.15.2",
]
[[package]]
@@ -2351,15 +2494,15 @@ checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6"
dependencies = [
"base64 0.22.1",
"http-body-util",
- "hyper 1.5.1",
+ "hyper 1.6.0",
"hyper-rustls",
"hyper-util",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"ipnet",
"metrics",
"metrics-util",
"quanta",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tracing",
]
@@ -2397,9 +2540,9 @@ dependencies = [
[[package]]
name = "minijinja"
-version = "2.5.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2c37e1b517d1dcd0e51dc36c4567b9d5a29262b3ec8da6cb5d35e27a8fb529b5"
+checksum = "6e36f1329330bb1614c94b78632b9ce45dd7d761f3304a1bed07b2990a7c5097"
dependencies = [
"serde",
"serde_json",
@@ -2407,9 +2550,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
-version = "2.5.0"
+version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7fe51f1a6a8285f03fcd1544d834234fe8db285f29e1c2253600c93b3ae19242"
+checksum = "8e807b6b15e36a4c808e92f78c2ac1f6776519a50d9cf6649819c759a8e7133c"
dependencies = [
"minijinja",
"serde",
@@ -2423,9 +2566,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
-version = "0.8.0"
+version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
+checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
dependencies = [
"adler2",
"simd-adler32",
@@ -2433,28 +2576,21 @@ dependencies = [
[[package]]
name = "mio"
-version = "1.0.2"
+version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
+checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
- "hermit-abi 0.3.9",
"libc",
"log",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
-[[package]]
-name = "mirai-annotations"
-version = "1.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
-
[[package]]
name = "monostate"
-version = "0.1.13"
+version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e"
+checksum = "aafe1be9d0c75642e3e50fedc7ecadf1ef1cbce6eb66462153fc44245343fbee"
dependencies = [
"monostate-impl",
"serde",
@@ -2462,13 +2598,13 @@ dependencies = [
[[package]]
name = "monostate-impl"
-version = "0.1.13"
+version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
+checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -2489,8 +2625,8 @@ dependencies = [
"bytes",
"futures",
"pin-project",
- "rand",
- "thiserror",
+ "rand 0.8.5",
+ "thiserror 1.0.69",
"tokio",
"tokio-util",
"tracing",
@@ -2498,9 +2634,9 @@ dependencies = [
[[package]]
name = "native-tls"
-version = "0.2.12"
+version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
+checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
@@ -2534,15 +2670,15 @@ dependencies = [
"bytes",
"futures",
"hostname",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"muxado",
"once_cell",
"parking_lot",
"regex",
- "rustls-pemfile",
+ "rustls-pemfile 1.0.4",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tokio-retry",
"tokio-util",
@@ -2556,7 +2692,7 @@ version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cfg-if",
"cfg_aliases 0.1.1",
"libc",
@@ -2568,7 +2704,7 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cfg-if",
"cfg_aliases 0.2.1",
"libc",
@@ -2668,7 +2804,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -2739,18 +2875,18 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
-version = "0.36.5"
+version = "0.36.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e"
+checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
-version = "1.20.2"
+version = "1.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
+checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc"
[[package]]
name = "onig"
@@ -2776,17 +2912,17 @@ dependencies = [
[[package]]
name = "oorandom"
-version = "11.1.4"
+version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
+checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openssl"
-version = "0.10.68"
+version = "0.10.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5"
+checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cfg-if",
"foreign-types",
"libc",
@@ -2803,20 +2939,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "openssl-probe"
-version = "0.1.5"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
+checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
-version = "0.9.104"
+version = "0.9.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741"
+checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
dependencies = [
"cc",
"libc",
@@ -2842,28 +2978,14 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a"
dependencies = [
"futures-core",
"futures-sink",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"js-sys",
"once_cell",
"pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
"urlencoding",
]
-[[package]]
-name = "opentelemetry"
-version = "0.24.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
-dependencies = [
- "futures-core",
- "futures-sink",
- "js-sys",
- "once_cell",
- "pin-project-lite",
- "thiserror",
-]
-
[[package]]
name = "opentelemetry-otlp"
version = "0.13.0"
@@ -2878,7 +3000,7 @@ dependencies = [
"opentelemetry_api",
"opentelemetry_sdk 0.20.0",
"prost 0.11.9",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tonic 0.9.2",
]
@@ -2916,7 +3038,7 @@ dependencies = [
"js-sys",
"once_cell",
"pin-project-lite",
- "thiserror",
+ "thiserror 1.0.69",
"urlencoding",
]
@@ -2935,10 +3057,10 @@ dependencies = [
"opentelemetry_api",
"ordered-float 3.9.2",
"percent-encoding",
- "rand",
+ "rand 0.8.5",
"regex",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tokio-stream",
]
@@ -2957,28 +3079,10 @@ dependencies = [
"glob",
"once_cell",
"opentelemetry 0.21.0",
- "ordered-float 4.5.0",
+ "ordered-float 4.6.0",
"percent-encoding",
- "rand",
- "thiserror",
-]
-
-[[package]]
-name = "opentelemetry_sdk"
-version = "0.24.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
-dependencies = [
- "async-trait",
- "futures-channel",
- "futures-executor",
- "futures-util",
- "glob",
- "once_cell",
- "opentelemetry 0.24.0",
- "percent-encoding",
- "rand",
- "thiserror",
+ "rand 0.8.5",
+ "thiserror 1.0.69",
]
[[package]]
@@ -2998,9 +3102,9 @@ dependencies = [
[[package]]
name = "ordered-float"
-version = "4.5.0"
+version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c65ee1f9701bf938026630b455d5315f490640234259037edb259798b3bcf85e"
+checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951"
dependencies = [
"num-traits",
]
@@ -3016,6 +3120,12 @@ dependencies = [
"serde_json",
]
+[[package]]
+name = "outref"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e"
+
[[package]]
name = "overload"
version = "0.1.1"
@@ -3075,34 +3185,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
]
[[package]]
name = "pin-project"
-version = "1.1.7"
+version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95"
+checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
-version = "1.1.7"
+version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c"
+checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "pin-project-lite"
-version = "0.2.15"
+version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"
+checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
@@ -3112,9 +3222,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkg-config"
-version = "0.3.31"
+version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
+checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
@@ -3146,9 +3256,9 @@ dependencies = [
[[package]]
name = "png"
-version = "0.17.14"
+version = "0.17.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
+checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526"
dependencies = [
"bitflags 1.3.2",
"crc32fast",
@@ -3159,9 +3269,9 @@ dependencies = [
[[package]]
name = "portable-atomic"
-version = "1.9.0"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
+checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
[[package]]
name = "powerfmt"
@@ -3171,31 +3281,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
-version = "0.2.20"
+version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
+checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
- "zerocopy",
-]
-
-[[package]]
-name = "pretty_assertions"
-version = "1.4.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
-dependencies = [
- "diff",
- "yansi",
+ "zerocopy 0.8.24",
]
[[package]]
name = "prettyplease"
-version = "0.2.25"
+version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033"
+checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb"
dependencies = [
"proc-macro2",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3224,9 +3324,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
-version = "1.0.92"
+version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
+checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84"
dependencies = [
"unicode-ident",
]
@@ -3247,7 +3347,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30"
dependencies = [
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3287,7 +3387,7 @@ dependencies = [
"prost 0.12.6",
"prost-types",
"regex",
- "syn 2.0.89",
+ "syn 2.0.100",
"tempfile",
]
@@ -3314,7 +3414,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3373,7 +3473,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3386,7 +3486,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -3400,15 +3500,15 @@ dependencies = [
[[package]]
name = "quanta"
-version = "0.12.3"
+version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
+checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
@@ -3421,13 +3521,19 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quote"
-version = "1.0.37"
+version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
+checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
+[[package]]
+name = "r-efi"
+version = "5.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
+
[[package]]
name = "rand"
version = "0.8.5"
@@ -3435,8 +3541,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
- "rand_chacha",
- "rand_core",
+ "rand_chacha 0.3.1",
+ "rand_core 0.6.4",
+]
+
+[[package]]
+name = "rand"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
+dependencies = [
+ "rand_chacha 0.9.0",
+ "rand_core 0.9.3",
+ "zerocopy 0.8.24",
]
[[package]]
@@ -3446,7 +3563,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
- "rand_core",
+ "rand_core 0.6.4",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
+dependencies = [
+ "ppv-lite86",
+ "rand_core 0.9.3",
]
[[package]]
@@ -3455,7 +3582,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
+dependencies = [
+ "getrandom 0.3.2",
]
[[package]]
@@ -3464,7 +3600,7 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"cassowary",
"compact_str",
"crossterm",
@@ -3505,11 +3641,11 @@ dependencies = [
"once_cell",
"paste",
"profiling",
- "rand",
- "rand_chacha",
+ "rand 0.8.5",
+ "rand_chacha 0.3.1",
"simd_helpers",
"system-deps",
- "thiserror",
+ "thiserror 1.0.69",
"v_frame",
"wasm-bindgen",
]
@@ -3531,11 +3667,11 @@ dependencies = [
[[package]]
name = "raw-cpuid"
-version = "11.2.0"
+version = "11.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0"
+checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
]
[[package]]
@@ -3571,11 +3707,11 @@ dependencies = [
[[package]]
name = "redox_syscall"
-version = "0.5.7"
+version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
+checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
]
[[package]]
@@ -3584,9 +3720,42 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"libredox",
- "thiserror",
+ "thiserror 1.0.69",
+]
+
+[[package]]
+name = "ref-cast"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
+dependencies = [
+ "ref-cast-impl",
+]
+
+[[package]]
+name = "ref-cast-impl"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "referencing"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d0dcb5ab28989ad7c91eb1b9531a37a1a137cc69a0499aee4117cae4a107c464"
+dependencies = [
+ "ahash",
+ "fluent-uri",
+ "once_cell",
+ "percent-encoding",
+ "serde_json",
]
[[package]]
@@ -3647,8 +3816,8 @@ dependencies = [
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
- "hyper-tls",
+ "hyper 0.14.32",
+ "hyper-tls 0.5.0",
"ipnet",
"js-sys",
"log",
@@ -3657,12 +3826,12 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
- "rustls-pemfile",
+ "rustls-pemfile 1.0.4",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper 0.1.2",
- "system-configuration",
+ "system-configuration 0.5.1",
"tokio",
"tokio-native-tls",
"tower-service",
@@ -3673,6 +3842,53 @@ dependencies = [
"winreg",
]
+[[package]]
+name = "reqwest"
+version = "0.12.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
+dependencies = [
+ "base64 0.22.1",
+ "bytes",
+ "encoding_rs",
+ "futures-channel",
+ "futures-core",
+ "futures-util",
+ "h2 0.4.8",
+ "http 1.3.1",
+ "http-body 1.0.1",
+ "http-body-util",
+ "hyper 1.6.0",
+ "hyper-rustls",
+ "hyper-tls 0.6.0",
+ "hyper-util",
+ "ipnet",
+ "js-sys",
+ "log",
+ "mime",
+ "native-tls",
+ "once_cell",
+ "percent-encoding",
+ "pin-project-lite",
+ "rustls-pemfile 2.2.0",
+ "serde",
+ "serde_json",
+ "serde_urlencoded",
+ "sync_wrapper 1.0.2",
+ "system-configuration 0.6.1",
+ "tokio",
+ "tokio-native-tls",
+ "tokio-util",
+ "tower 0.5.2",
+ "tower-service",
+ "url",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "wasm-streams",
+ "web-sys",
+ "windows-registry",
+]
+
[[package]]
name = "rgb"
version = "0.8.50"
@@ -3688,7 +3904,7 @@ dependencies = [
"cc",
"libc",
"once_cell",
- "spin 0.5.2",
+ "spin",
"untrusted 0.7.1",
"web-sys",
"winapi",
@@ -3696,24 +3912,23 @@ dependencies = [
[[package]]
name = "ring"
-version = "0.17.8"
+version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
+checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
dependencies = [
"cc",
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"libc",
- "spin 0.9.8",
"untrusted 0.9.0",
"windows-sys 0.52.0",
]
[[package]]
name = "rust-embed"
-version = "8.5.0"
+version = "8.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0"
+checksum = "0b3aba5104622db5c9fc61098de54708feb732e7763d7faa2fa625899f00bf6f"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
@@ -3722,22 +3937,22 @@ dependencies = [
[[package]]
name = "rust-embed-impl"
-version = "8.5.0"
+version = "8.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478"
+checksum = "1f198c73be048d2c5aa8e12f7960ad08443e56fd39cc26336719fdb4ea0ebaae"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
- "syn 2.0.89",
+ "syn 2.0.100",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
-version = "8.5.0"
+version = "8.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d"
+checksum = "5a2fcdc9f40c8dc2922842ca9add611ad19f332227fc651d015881ad1552bd9a"
dependencies = [
"sha2",
"walkdir",
@@ -3755,6 +3970,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+[[package]]
+name = "rustc-hash"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
+
[[package]]
name = "rustc_version"
version = "0.4.1"
@@ -3766,15 +3987,28 @@ dependencies = [
[[package]]
name = "rustix"
-version = "0.38.41"
+version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6"
+checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"errno",
"libc",
- "linux-raw-sys",
- "windows-sys 0.52.0",
+ "linux-raw-sys 0.4.15",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
+name = "rustix"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e56a18552996ac8d29ecc3b190b4fdbb2d91ca4ec396de7bbffaf43f3d637e96"
+dependencies = [
+ "bitflags 2.9.0",
+ "errno",
+ "libc",
+ "linux-raw-sys 0.9.3",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -3796,24 +4030,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"log",
- "ring 0.17.8",
+ "ring 0.17.14",
"rustls-pki-types",
- "rustls-webpki",
+ "rustls-webpki 0.102.8",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
-version = "0.23.17"
+version = "0.23.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e"
+checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"rustls-pki-types",
- "rustls-webpki",
+ "rustls-webpki 0.103.0",
"subtle",
"zeroize",
]
@@ -3827,7 +4061,7 @@ dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
- "security-framework 3.0.1",
+ "security-framework 3.2.0",
]
[[package]]
@@ -3840,34 +4074,54 @@ dependencies = [
]
[[package]]
-name = "rustls-pki-types"
-version = "1.10.0"
+name = "rustls-pemfile"
+version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b"
+checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
+dependencies = [
+ "rustls-pki-types",
+]
+
+[[package]]
+name = "rustls-pki-types"
+version = "1.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
[[package]]
name = "rustls-webpki"
version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
+dependencies = [
+ "ring 0.17.14",
+ "rustls-pki-types",
+ "untrusted 0.9.0",
+]
+
+[[package]]
+name = "rustls-webpki"
+version = "0.103.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0aa4eeac2588ffff23e9d7a7e9b3f971c5fb5b7ebc9452745e0c232c64f83b2f"
dependencies = [
"aws-lc-rs",
- "ring 0.17.8",
+ "ring 0.17.14",
"rustls-pki-types",
"untrusted 0.9.0",
]
[[package]]
name = "rustversion"
-version = "1.0.18"
+version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248"
+checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2"
[[package]]
name = "ryu"
-version = "1.0.18"
+version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
+checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "same-file"
@@ -3895,9 +4149,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scratch"
-version = "1.0.7"
+version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
+checksum = "9f6280af86e5f559536da57a45ebc84948833b3bee313a7dd25232e09c878a52"
[[package]]
name = "sct"
@@ -3905,7 +4159,7 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
- "ring 0.17.8",
+ "ring 0.17.14",
"untrusted 0.9.0",
]
@@ -3915,7 +4169,7 @@ version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
@@ -3924,11 +4178,11 @@ dependencies = [
[[package]]
name = "security-framework"
-version = "3.0.1"
+version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8"
+checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"core-foundation 0.10.0",
"core-foundation-sys",
"libc",
@@ -3937,9 +4191,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
-version = "2.12.1"
+version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2"
+checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
@@ -3947,18 +4201,18 @@ dependencies = [
[[package]]
name = "semver"
-version = "1.0.23"
+version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
+checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
dependencies = [
"serde",
]
[[package]]
name = "serde"
-version = "1.0.215"
+version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
+checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
@@ -3985,22 +4239,22 @@ dependencies = [
[[package]]
name = "serde_derive"
-version = "1.0.215"
+version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
+checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "serde_json"
-version = "1.0.133"
+version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
+checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"itoa",
"memchr",
"ryu",
@@ -4009,9 +4263,9 @@ dependencies = [
[[package]]
name = "serde_path_to_error"
-version = "0.1.16"
+version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
+checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
@@ -4135,32 +4389,37 @@ dependencies = [
[[package]]
name = "smallvec"
-version = "1.13.2"
+version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
+checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
[[package]]
name = "socket2"
-version = "0.5.7"
+version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c"
+checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
+[[package]]
+name = "socks"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
+dependencies = [
+ "byteorder",
+ "libc",
+ "winapi",
+]
+
[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-
[[package]]
name = "spm_precompiled"
version = "0.1.4"
@@ -4210,7 +4469,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4232,9 +4491,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.89"
+version = "2.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e"
+checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0"
dependencies = [
"proc-macro2",
"quote",
@@ -4252,6 +4511,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
+dependencies = [
+ "futures-core",
+]
[[package]]
name = "synstructure"
@@ -4261,7 +4523,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4287,7 +4549,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation 0.9.4",
- "system-configuration-sys",
+ "system-configuration-sys 0.5.0",
+]
+
+[[package]]
+name = "system-configuration"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
+dependencies = [
+ "bitflags 2.9.0",
+ "core-foundation 0.9.4",
+ "system-configuration-sys 0.6.0",
]
[[package]]
@@ -4300,6 +4573,16 @@ dependencies = [
"libc",
]
+[[package]]
+name = "system-configuration-sys"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
+dependencies = [
+ "core-foundation-sys",
+ "libc",
+]
+
[[package]]
name = "system-deps"
version = "6.2.2"
@@ -4345,14 +4628,14 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
-version = "3.14.0"
+version = "3.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
+checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf"
dependencies = [
- "cfg-if",
"fastrand",
+ "getrandom 0.3.2",
"once_cell",
- "rustix",
+ "rustix 1.0.3",
"windows-sys 0.59.0",
]
@@ -4367,42 +4650,39 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
- "async-stream",
"async-trait",
- "clap 4.5.21",
+ "clap 4.5.32",
"cmake",
"cxx",
"cxx-build",
- "hashbrown 0.14.5",
- "hf-hub",
- "log",
+ "hashbrown 0.15.2",
+ "hf-hub 0.4.2",
"pkg-config",
+ "pyo3",
"text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
"tracing",
- "tracing-opentelemetry 0.25.0",
- "tracing-subscriber",
]
[[package]]
name = "text-generation-benchmark"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
"average",
- "clap 4.5.21",
+ "clap 4.5.32",
"float-ord",
- "hf-hub",
+ "hf-hub 0.4.2",
"ratatui",
"serde",
"serde_json",
"tabled",
"text-generation-client",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tracing",
@@ -4411,7 +4691,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -4419,7 +4699,7 @@ dependencies = [
"grpc-metadata",
"prost 0.12.6",
"prost-build",
- "thiserror",
+ "thiserror 1.0.69",
"tokio",
"tonic 0.10.2",
"tonic-build",
@@ -4429,20 +4709,20 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
- "clap 4.5.21",
+ "clap 4.5.32",
"ctrlc",
"float_eq",
- "hf-hub",
+ "hf-hub 0.4.2",
"nix 0.28.0",
"once_cell",
"pyo3",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"tracing",
"tracing-subscriber",
"vergen",
@@ -4450,7 +4730,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
"anyhow",
"async-stream",
@@ -4458,11 +4738,12 @@ dependencies = [
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
- "clap 4.5.21",
+ "chrono",
+ "clap 4.5.32",
"csv",
"futures",
"futures-util",
- "hf-hub",
+ "hf-hub 0.4.2",
"image",
"init-tracing-opentelemetry",
"itertools 0.10.5",
@@ -4478,13 +4759,13 @@ dependencies = [
"opentelemetry-otlp",
"outlines-core",
"pyo3",
- "rand",
+ "rand 0.8.5",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
"serde",
"serde_json",
"sysinfo",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@@ -4499,20 +4780,38 @@ dependencies = [
"vergen",
]
+[[package]]
+name = "text-generation-router-llamacpp"
+version = "3.2.3-dev0"
+dependencies = [
+ "async-trait",
+ "bindgen 0.71.1",
+ "clap 4.5.32",
+ "hf-hub 0.4.2",
+ "num_cpus",
+ "pkg-config",
+ "text-generation-router",
+ "thiserror 2.0.12",
+ "tokenizers",
+ "tokio",
+ "tokio-stream",
+ "tracing",
+]
+
[[package]]
name = "text-generation-router-v2"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
- "clap 4.5.21",
+ "clap 4.5.32",
"futures",
"futures-util",
"grpc-metadata",
- "hf-hub",
+ "hf-hub 0.4.2",
"image",
"init-tracing-opentelemetry",
"jsonschema",
@@ -4526,14 +4825,14 @@ dependencies = [
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
- "rand",
+ "rand 0.8.5",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
"serde",
"serde_json",
"slotmap",
"text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@@ -4550,19 +4849,19 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
-version = "3.0.1-dev0"
+version = "3.2.3-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.9",
"axum-tracing-opentelemetry",
"base64 0.22.1",
- "clap 4.5.21",
+ "clap 4.5.32",
"criterion",
"futures",
"futures-util",
"grpc-metadata",
- "hf-hub",
+ "hf-hub 0.4.2",
"image",
"init-tracing-opentelemetry",
"itertools 0.13.0",
@@ -4577,14 +4876,15 @@ dependencies = [
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
- "rand",
+ "rand 0.8.5",
"regex",
- "reqwest",
+ "reqwest 0.11.27",
+ "rustc-hash 2.1.1",
"serde",
"serde_json",
"slotmap",
"text-generation-router",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-stream",
@@ -4614,7 +4914,16 @@ version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [
- "thiserror-impl",
+ "thiserror-impl 1.0.69",
+]
+
+[[package]]
+name = "thiserror"
+version = "2.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
+dependencies = [
+ "thiserror-impl 2.0.12",
]
[[package]]
@@ -4625,7 +4934,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "2.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
]
[[package]]
@@ -4651,9 +4971,9 @@ dependencies = [
[[package]]
name = "time"
-version = "0.3.36"
+version = "0.3.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
+checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40"
dependencies = [
"deranged",
"itoa",
@@ -4668,15 +4988,15 @@ dependencies = [
[[package]]
name = "time-core"
-version = "0.1.2"
+version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
+checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c"
[[package]]
name = "time-macros"
-version = "0.2.18"
+version = "0.2.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
+checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49"
dependencies = [
"num-conv",
"time-core",
@@ -4704,15 +5024,15 @@ dependencies = [
[[package]]
name = "tokenizers"
-version = "0.20.3"
+version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7"
+checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
- "getrandom",
- "hf-hub",
+ "getrandom 0.2.15",
+ "hf-hub 0.3.2",
"indicatif",
"itertools 0.12.1",
"lazy_static",
@@ -4721,7 +5041,7 @@ dependencies = [
"monostate",
"onig",
"paste",
- "rand",
+ "rand 0.8.5",
"rayon",
"rayon-cond",
"regex",
@@ -4729,7 +5049,7 @@ dependencies = [
"serde",
"serde_json",
"spm_precompiled",
- "thiserror",
+ "thiserror 1.0.69",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
@@ -4737,9 +5057,9 @@ dependencies = [
[[package]]
name = "tokio"
-version = "1.41.1"
+version = "1.44.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
+checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a"
dependencies = [
"backtrace",
"bytes",
@@ -4765,13 +5085,13 @@ dependencies = [
[[package]]
name = "tokio-macros"
-version = "2.4.0"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
+checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4791,26 +5111,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
dependencies = [
"pin-project",
- "rand",
+ "rand 0.8.5",
"tokio",
]
[[package]]
name = "tokio-rustls"
-version = "0.26.0"
+version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
+checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
dependencies = [
- "rustls 0.23.17",
- "rustls-pki-types",
+ "rustls 0.23.25",
"tokio",
]
[[package]]
name = "tokio-stream"
-version = "0.1.16"
+version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
+checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
@@ -4819,9 +5138,9 @@ dependencies = [
[[package]]
name = "tokio-util"
-version = "0.7.12"
+version = "0.7.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a"
+checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
dependencies = [
"bytes",
"futures-core",
@@ -4833,9 +5152,9 @@ dependencies = [
[[package]]
name = "toml"
-version = "0.8.19"
+version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
+checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148"
dependencies = [
"serde",
"serde_spanned",
@@ -4854,11 +5173,11 @@ dependencies = [
[[package]]
name = "toml_edit"
-version = "0.22.22"
+version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
+checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"serde",
"serde_spanned",
"toml_datetime",
@@ -4880,7 +5199,7 @@ dependencies = [
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -4907,7 +5226,7 @@ dependencies = [
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
- "hyper 0.14.31",
+ "hyper 0.14.32",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -4930,7 +5249,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -4944,7 +5263,7 @@ dependencies = [
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
- "rand",
+ "rand 0.8.5",
"slab",
"tokio",
"tokio-util",
@@ -4955,14 +5274,14 @@ dependencies = [
[[package]]
name = "tower"
-version = "0.5.1"
+version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
+checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
- "sync_wrapper 0.1.2",
+ "sync_wrapper 1.0.2",
"tokio",
"tower-layer",
"tower-service",
@@ -4975,9 +5294,9 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
- "bitflags 2.6.0",
+ "bitflags 2.9.0",
"bytes",
- "http 1.1.0",
+ "http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"pin-project-lite",
@@ -4999,9 +5318,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
-version = "0.1.40"
+version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
+checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"log",
"pin-project-lite",
@@ -5011,20 +5330,20 @@ dependencies = [
[[package]]
name = "tracing-attributes"
-version = "0.1.27"
+version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
+checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
name = "tracing-core"
-version = "0.1.32"
+version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
+checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
dependencies = [
"once_cell",
"valuable",
@@ -5086,31 +5405,13 @@ dependencies = [
"web-time 0.2.4",
]
-[[package]]
-name = "tracing-opentelemetry"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
-dependencies = [
- "js-sys",
- "once_cell",
- "opentelemetry 0.24.0",
- "opentelemetry_sdk 0.24.1",
- "smallvec",
- "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
- "tracing-subscriber",
- "web-time 1.1.0",
-]
-
[[package]]
name = "tracing-opentelemetry-instrumentation-sdk"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9920abb6a3ee3a2af7d30c9ff02900f8481935d36723c3da95cf807468218e8c"
dependencies = [
- "http 1.1.0",
+ "http 1.3.1",
"opentelemetry 0.21.0",
"tracing",
"tracing-opentelemetry 0.22.0",
@@ -5118,9 +5419,9 @@ dependencies = [
[[package]]
name = "tracing-serde"
-version = "0.1.3"
+version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
+checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1"
dependencies = [
"serde",
"tracing-core",
@@ -5128,9 +5429,9 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
-version = "0.3.18"
+version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
+checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"matchers",
"nu-ansi-term",
@@ -5155,21 +5456,21 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "typenum"
-version = "1.17.0"
+version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
+checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "unicase"
-version = "2.8.0"
+version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df"
+checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
-version = "1.0.14"
+version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
+checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "unicode-normalization-alignments"
@@ -5217,9 +5518,9 @@ checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
-version = "0.2.3"
+version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
+checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
[[package]]
name = "untrusted"
@@ -5246,18 +5547,19 @@ dependencies = [
"once_cell",
"rustls 0.22.4",
"rustls-pki-types",
- "rustls-webpki",
+ "rustls-webpki 0.102.8",
"serde",
"serde_json",
+ "socks",
"url",
"webpki-roots",
]
[[package]]
name = "url"
-version = "2.5.3"
+version = "2.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada"
+checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60"
dependencies = [
"form_urlencoded",
"idna",
@@ -5294,7 +5596,7 @@ version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
dependencies = [
- "indexmap 2.6.0",
+ "indexmap 2.8.0",
"serde",
"serde_json",
"utoipa-gen",
@@ -5310,7 +5612,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -5331,24 +5633,35 @@ dependencies = [
[[package]]
name = "uuid"
-version = "1.11.0"
+version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a"
+checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
dependencies = [
- "getrandom",
- "rand",
+ "getrandom 0.3.2",
+ "rand 0.9.0",
"uuid-macro-internal",
]
[[package]]
name = "uuid-macro-internal"
-version = "1.11.0"
+version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b91f57fe13a38d0ce9e28a03463d8d3c2468ed03d75375110ec71d93b449a08"
+checksum = "72dcd78c4f979627a754f5522cea6e6a25e55139056535fe6e69c506cd64a862"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "uuid-simd"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8"
+dependencies = [
+ "outref",
+ "uuid",
+ "vsimd",
]
[[package]]
@@ -5364,9 +5677,9 @@ dependencies = [
[[package]]
name = "valuable"
-version = "0.1.0"
+version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
+checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
@@ -5402,6 +5715,12 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
+[[package]]
+name = "vsimd"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
+
[[package]]
name = "walkdir"
version = "2.5.0"
@@ -5428,48 +5747,58 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
-name = "wasm-bindgen"
-version = "0.2.95"
+name = "wasi"
+version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e"
+checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
+dependencies = [
+ "wit-bindgen-rt",
+]
+
+[[package]]
+name = "wasm-bindgen"
+version = "0.2.100"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
+ "rustversion",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358"
+checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
dependencies = [
"bumpalo",
"log",
- "once_cell",
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
-version = "0.4.45"
+version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b"
+checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if",
"js-sys",
+ "once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56"
+checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@@ -5477,28 +5806,44 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68"
+checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
-version = "0.2.95"
+version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d"
+checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "wasm-streams"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
+dependencies = [
+ "futures-util",
+ "js-sys",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
[[package]]
name = "web-sys"
-version = "0.3.72"
+version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112"
+checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
@@ -5530,15 +5875,15 @@ version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53"
dependencies = [
- "ring 0.17.8",
+ "ring 0.17.14",
"untrusted 0.9.0",
]
[[package]]
name = "webpki-roots"
-version = "0.26.7"
+version = "0.26.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
+checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9"
dependencies = [
"rustls-pki-types",
]
@@ -5558,7 +5903,7 @@ dependencies = [
"either",
"home",
"once_cell",
- "rustix",
+ "rustix 0.38.44",
]
[[package]]
@@ -5611,6 +5956,41 @@ dependencies = [
"windows-targets 0.52.6",
]
+[[package]]
+name = "windows-link"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38"
+
+[[package]]
+name = "windows-registry"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
+dependencies = [
+ "windows-result",
+ "windows-strings",
+ "windows-targets 0.53.0",
+]
+
+[[package]]
+name = "windows-result"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252"
+dependencies = [
+ "windows-link",
+]
+
+[[package]]
+name = "windows-strings"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319"
+dependencies = [
+ "windows-link",
+]
+
[[package]]
name = "windows-sys"
version = "0.45.0"
@@ -5686,13 +6066,29 @@ dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
- "windows_i686_gnullvm",
+ "windows_i686_gnullvm 0.52.6",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
+[[package]]
+name = "windows-targets"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b"
+dependencies = [
+ "windows_aarch64_gnullvm 0.53.0",
+ "windows_aarch64_msvc 0.53.0",
+ "windows_i686_gnu 0.53.0",
+ "windows_i686_gnullvm 0.53.0",
+ "windows_i686_msvc 0.53.0",
+ "windows_x86_64_gnu 0.53.0",
+ "windows_x86_64_gnullvm 0.53.0",
+ "windows_x86_64_msvc 0.53.0",
+]
+
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
@@ -5711,6 +6107,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
+
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
@@ -5729,6 +6131,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
+
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
@@ -5747,12 +6155,24 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
+[[package]]
+name = "windows_i686_gnu"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
+
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
+[[package]]
+name = "windows_i686_gnullvm"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
+
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
@@ -5771,6 +6191,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
+[[package]]
+name = "windows_i686_msvc"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
+
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
@@ -5789,6 +6215,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
+
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
@@ -5807,6 +6239,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.53.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
+
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
@@ -5826,10 +6264,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
-name = "winnow"
-version = "0.6.20"
+name = "windows_x86_64_msvc"
+version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
+checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
+
+[[package]]
+name = "winnow"
+version = "0.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36"
dependencies = [
"memchr",
]
@@ -5844,6 +6288,15 @@ dependencies = [
"windows-sys 0.48.0",
]
+[[package]]
+name = "wit-bindgen-rt"
+version = "0.39.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
+dependencies = [
+ "bitflags 2.9.0",
+]
+
[[package]]
name = "write16"
version = "1.0.0"
@@ -5856,17 +6309,11 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
-[[package]]
-name = "yansi"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
-
[[package]]
name = "yoke"
-version = "0.7.4"
+version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5"
+checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40"
dependencies = [
"serde",
"stable_deref_trait",
@@ -5876,13 +6323,13 @@ dependencies = [
[[package]]
name = "yoke-derive"
-version = "0.7.4"
+version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95"
+checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"synstructure",
]
@@ -5892,8 +6339,16 @@ version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [
- "byteorder",
- "zerocopy-derive",
+ "zerocopy-derive 0.7.35",
+]
+
+[[package]]
+name = "zerocopy"
+version = "0.8.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879"
+dependencies = [
+ "zerocopy-derive 0.8.24",
]
[[package]]
@@ -5904,27 +6359,38 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.8.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
]
[[package]]
name = "zerofrom"
-version = "0.1.4"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55"
+checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
-version = "0.1.4"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5"
+checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
"synstructure",
]
@@ -5953,7 +6419,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.89",
+ "syn 2.0.100",
]
[[package]]
@@ -5985,9 +6451,9 @@ dependencies = [
[[package]]
name = "zune-jpeg"
-version = "0.4.13"
+version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768"
+checksum = "99a5bab8d7dedf81405c4bb1f2b83ea057643d9cb28778cea9eecddeedd2e028"
dependencies = [
"zune-core",
]
diff --git a/Cargo.toml b/Cargo.toml
index d81551533..1bc736bab 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,26 +1,27 @@
[workspace]
members = [
- "benchmark",
- "backends/v2",
- "backends/v3",
- "backends/grpc-metadata",
- "backends/trtllm",
- "launcher",
- "router"
+ "benchmark",
+ "backends/v2",
+ "backends/v3",
+ "backends/grpc-metadata",
+ "backends/trtllm",
+ "backends/llamacpp",
+ "launcher",
+ "router"
]
default-members = [
- "benchmark",
- "backends/v2",
- "backends/v3",
- "backends/grpc-metadata",
- # "backends/trtllm",
- "launcher",
- "router"
+ "benchmark",
+ "backends/v2",
+ "backends/v3",
+ "backends/grpc-metadata",
+ # "backends/trtllm",
+ "launcher",
+ "router"
]
resolver = "2"
[workspace.package]
-version = "3.0.2-dev0"
+version = "3.2.3-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
@@ -28,7 +29,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
-hf-hub = { version = "0.3.1", features = ["tokio"] }
+hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
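Note on the workspace change above: `backends/llamacpp` is now a member while `backends/trtllm` stays commented out of `default-members`, so a bare `cargo build` still skips both backends. A minimal sketch of building the new member explicitly; the package name is the one Dockerfile_llamacpp builds later in this diff:

    # build only the llama.cpp router, leaving the default members untouched
    cargo build --release --package text-generation-router-llamacpp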
diff --git a/Dockerfile b/Dockerfile
index 0c08d48f6..03840b971 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -45,21 +45,16 @@ RUN cargo build --profile release-opt --frozen
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
+WORKDIR /usr/src/
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.4.0
-
+ARG PYTORCH_VERSION=2.6
ARG PYTHON_VERSION=3.11
+
# Keep in sync with `server/pyproject.toml
-ARG CUDA_VERSION=12.4
-ARG MAMBA_VERSION=24.3.0-0
-ARG CUDA_CHANNEL=nvidia
-ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
-
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
@@ -67,26 +62,12 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git && \
rm -rf /var/lib/apt/lists/*
-
-# Install conda
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") MAMBA_ARCH=aarch64 ;; \
- *) MAMBA_ARCH=x86_64 ;; \
- esac && \
- curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
- bash ~/mambaforge.sh -b -p /opt/conda && \
- rm ~/mambaforge.sh
-
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") exit 1 ;; \
- *) /opt/conda/bin/conda update -y conda && \
- /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
- esac && \
- /opt/conda/bin/conda clean -ya
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
+ENV PATH="$PATH:/root/.local/bin"
+RUN uv python install ${PYTHON_VERSION}
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"
# CUDA kernels builder image
FROM pytorch-install AS kernel-builder
@@ -106,7 +87,7 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
-RUN make build-flash-attention
+RUN . .venv/bin/activate && make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder
@@ -116,14 +97,14 @@ WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
-RUN make build-flash-attention-v2-cuda
+RUN . .venv/bin/activate && make build-flash-attention-v2-cuda
# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
-RUN python setup.py build
+RUN . .venv/bin/activate && python setup.py build
# Build Transformers exllama kernels
FROM kernel-builder AS exllamav2-kernels-builder
@@ -131,54 +112,43 @@ WORKDIR /usr/src
COPY server/Makefile-exllamav2/ Makefile
# Build specific version of transformers
-RUN make build-exllamav2
+RUN . .venv/bin/activate && make build-exllamav2
# Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of transformers
-RUN make build-awq
-
-# Build eetq kernels
-FROM kernel-builder AS eetq-kernels-builder
-WORKDIR /usr/src
-COPY server/Makefile-eetq Makefile
-# Build specific version of transformers
-RUN make build-eetq
+RUN . .venv/bin/activate && make build-awq
# Build Lorax Punica kernels
FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
+RUN . .venv/bin/activate && TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build specific version of transformers
-RUN python setup.py build
+RUN . .venv/bin/activate && python setup.py build
# Build mamba kernels
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
-RUN make build-all
+RUN . .venv/bin/activate && make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
-RUN make install-flashinfer
+RUN . .venv/bin/activate && make install-flashinfer
# Text Generation Inference base image
-FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
-
-# Conda env
-ENV PATH=/opt/conda/bin:$PATH \
- CONDA_PREFIX=/opt/conda
+FROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base
# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -195,50 +165,61 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
git \
&& rm -rf /var/lib/apt/lists/*
-# Copy conda with PyTorch installed
-COPY --from=pytorch-install /opt/conda /opt/conda
-
-# Copy build artifacts from flash attention builder
-COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from awq kernels builder
-COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from eetq kernels builder
-COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-# Copy build artifacts from mamba builder
-COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
-COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
-COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
-
+# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+# ENV PATH="$PATH:/root/.local/bin"
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
# Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
+# RUN pip install einops --no-cache-dir
+
+# Copy env with PyTorch installed
+COPY --from=pytorch-install /usr/src/.venv /usr/src/.venv
+ENV PYTHON_VERSION=3.11
+RUN uv python install ${PYTHON_VERSION}
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
+ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \
- make gen-server && \
- pip install -r requirements_cuda.txt && \
- pip install ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
- pip install nvidia-nccl-cu12==2.22.3
+ uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
+ make gen-server-raw && \
+ kernels download .
-ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
+RUN cd server && \
+ uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
+ uv pip install nvidia-nccl-cu12==2.25.1 && \
+ pwd && \
+ text-generation-server --help
+
+# Copy build artifacts from flash attention builder
+COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages
+
+# Copy build artifacts from custom kernels builder
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+# Copy build artifacts from awq kernels builder
+COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+# Copy build artifacts from lorax punica kernels builder
+COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
+# Copy build artifacts from mamba builder
+COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
+COPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/
+
+
+# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
-ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1
@@ -273,5 +254,6 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/"
ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
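For reference, the conda/mamba bootstrap removed above is replaced by a uv-managed virtualenv. A rough shell equivalent of what the `pytorch-install` stage now does, assuming the same Python 3.11 and torch 2.6 pins as the ARGs above (the image copies the `uv` binary from `ghcr.io/astral-sh/uv:0.5.31` rather than running the install script):

    curl -LsSf https://astral.sh/uv/install.sh | sh   # the Dockerfile copies /uv from the uv image instead
    uv python install 3.11
    uv venv --python 3.11
    uv pip install torch==2.6 torchvision pip setuptools packaging
    export VIRTUAL_ENV=/usr/src/.venv/
    export PATH="$PATH:/usr/src/.venv/bin/"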
diff --git a/Dockerfile.neuron b/Dockerfile.neuron
new file mode 100644
index 000000000..d22ca2228
--- /dev/null
+++ b/Dockerfile.neuron
@@ -0,0 +1,167 @@
+# Fetch and extract the TGI sources
+FROM alpine AS tgi
+RUN mkdir -p /tgi
+
+# Fetch the optimum-neuron sources directly to avoid relying on PyPI deployments
+FROM alpine AS optimum-neuron
+RUN mkdir -p /optimum-neuron
+ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.1.0.tar.gz /optimum-neuron/sources.tar.gz
+RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
+
+# Build cargo components (adapted from TGI original Dockerfile)
+# Note: we cannot use the cargo-chef base image as it uses python 3.11
+FROM ubuntu:22.04 AS chef
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ curl ca-certificates build-essential \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
+ENV PATH="/root/.cargo/bin:${PATH}"
+RUN cargo install cargo-chef --locked
+
+WORKDIR /usr/src
+
+FROM chef AS planner
+COPY backends/neuron/Cargo.toml Cargo.toml
+COPY Cargo.lock Cargo.lock
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ unzip python3-dev libssl-dev pkg-config \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+ rm -f $PROTOC_ZIP
+
+COPY backends/neuron/Cargo.toml Cargo.toml
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --release --recipe-path recipe.json
+
+COPY Cargo.lock Cargo.lock
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo build --release
+
+# Python base image
+FROM ubuntu:22.04 AS base
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ python3-pip \
+ python3-setuptools \
+ python-is-python3 \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+RUN pip3 --no-cache-dir install --upgrade pip
+
+# Python server build image
+FROM base AS pyserver
+
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ make \
+ python3-venv \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN install -d /pyserver
+WORKDIR /pyserver
+COPY backends/neuron/server server
+COPY proto proto
+RUN pip3 install -r server/build-requirements.txt
+RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package
+
+# Neuron base image (used for deployment)
+FROM base AS neuron
+
+# Install system prerequisites
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ gnupg2 \
+ wget \
+ python3-dev \
+ libexpat1 \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
+RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
+
+# Install neuronx packages
+RUN apt-get update -y \
+ && apt-get install -y --no-install-recommends \
+ aws-neuronx-dkms=2.19.64.0 \
+ aws-neuronx-collectives=2.23.135.0-3e70920f2 \
+ aws-neuronx-runtime-lib=2.23.112.0-9b5179492 \
+ aws-neuronx-tools=2.20.204.0 \
+ libxml2 \
+ && rm -rf /var/lib/apt/lists/* \
+ && apt-get clean
+
+ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
+
+# Manually install the CPU build of torch to avoid pulling in CUDA
+RUN pip3 install \
+ torch==2.5.1 \
+ torchvision==0.20.1 \
+ --index-url https://download.pytorch.org/whl/cpu
+
+RUN pip3 install \
+ neuronx-cc==2.16.372.0 \
+ torch-neuronx==2.5.1.2.4.0 \
+ transformers-neuronx==0.13.322 \
+ neuronx-distributed==0.10.1 \
+ libneuronxla==2.1.681.0 \
+ --extra-index-url=https://pip.repos.neuron.amazonaws.com
+
+# Install HuggingFace packages
+RUN pip3 install \
+ hf_transfer huggingface_hub
+
+# Install optimum-neuron
+COPY --from=optimum-neuron /optimum-neuron optimum-neuron
+RUN pip3 install ./optimum-neuron
+
+# TGI base env
+ENV HUGGINGFACE_HUB_CACHE=/tmp \
+ HF_HUB_ENABLE_HF_TRANSFER=1 \
+ PORT=80
+
+# Disable color logs as they are not supported by CloudWatch
+ENV LOGURU_COLORIZE=NO
+ENV LOG_COLORIZE=0
+
+# Install router
+COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+# Install python server
+COPY --from=pyserver /pyserver/build/dist dist
+RUN pip install dist/text_generation_server*.tar.gz
+
+# Final image
+FROM neuron
+
+COPY backends/neuron/tgi_env.py /tgi_env.py
+COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
+RUN chmod +x /tgi-entrypoint.sh
+
+ENTRYPOINT ["/tgi-entrypoint.sh"]
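A rough usage sketch for the new Neuron image; the tag and model id are placeholders, and it assumes the entrypoint forwards its arguments to `text-generation-launcher` as in the other images (the container listens on port 80 per the `PORT=80` env above):

    docker build -f Dockerfile.neuron -t tgi-neuron .
    docker run --device=/dev/neuron0 -p 8080:80 tgi-neuron --model-id <model>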
diff --git a/Dockerfile_amd b/Dockerfile_amd
index 7638947a5..e3e9efda8 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -1,5 +1,5 @@
# Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -41,262 +41,237 @@ COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
-# Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.2 AS base
+FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
+ARG HIPBLASLT_BRANCH="4d40e36"
+ARG HIPBLAS_COMMON_BRANCH="7c1566b"
+ARG LEGACY_HIPBLASLT_OPTION=
+ARG RCCL_BRANCH="648a58d"
+ARG RCCL_REPO="https://github.com/ROCm/rccl"
+ARG TRITON_BRANCH="e5be006"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+ARG PYTORCH_BRANCH="3a585126"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
+ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
+ARG FA_BRANCH="b7d29fb"
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG AITER_BRANCH="21d47a9"
+ARG AITER_REPO="https://github.com/ROCm/aiter.git"
+
+ENV PATH=/opt/rocm/llvm/bin:$PATH
+ENV ROCM_PATH=/opt/rocm
+ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
+ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+
+ARG PYTHON_VERSION=3.11
+
+RUN mkdir -p /app
+WORKDIR /app
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install Python and other dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
- build-essential \
- ca-certificates \
- ccache \
- curl \
- git \
- make \
- libmsgpack-dev \
- libssl-dev \
- llvm-dev \
- g++ \
- # Needed to build VLLM & flash.
- rocthrust-dev \
- hipsparse-dev \
- hipblas-dev \
- hipcub-dev \
- rocblas-dev \
- hiprand-dev \
- hipfft-dev \
- rocrand-dev \
- miopen-hip-dev \
- hipsolver-dev \
- rccl-dev \
- cmake \
- python3.11-venv && \
- rm -rf /var/lib/apt/lists/*
+ build-essential \
+ ca-certificates \
+ ccache \
+ curl \
+ git \
+ ninja-build \
+ cmake \
+ software-properties-common \
+ python3.11-dev \
+ python3.11-venv && \
+ rm -rf /var/lib/apt/lists/*
-# Keep in sync with `server/pyproject.toml
-ARG MAMBA_VERSION=23.1.0-1
-ARG PYTHON_VERSION='3.11.10'
-# Automatically set by buildx
-ARG TARGETPLATFORM
-ENV PATH=/opt/conda/bin:$PATH
+COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
+ENV PATH="$PATH:/root/.local/bin"
+RUN uv python install ${PYTHON_VERSION}
+RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
+ENV VIRTUAL_ENV=/usr/src/.venv/
+ENV PATH="$PATH:/usr/src/.venv/bin/"
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
-# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
-# Install mamba
-# translating Docker's TARGETPLATFORM into mamba arches
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") MAMBA_ARCH=aarch64 ;; \
- *) MAMBA_ARCH=x86_64 ;; \
- esac && \
- curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
-RUN chmod +x ~/mambaforge.sh && \
- bash ~/mambaforge.sh -b -p /opt/conda && \
- mamba init && \
- rm ~/mambaforge.sh
-
-# RUN conda install intel::mkl-static intel::mkl-include
-# Install pytorch
-# On arm64 we exit with an error code
-RUN case ${TARGETPLATFORM} in \
- "linux/arm64") exit 1 ;; \
- *) /opt/conda/bin/conda update -y conda && \
- /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
- esac && \
- /opt/conda/bin/conda clean -ya
-
-# Install flash-attention, torch dependencies
-RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/*
-
-RUN conda install mkl=2021
-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/
-
-
-ARG COMMON_WORKDIR=/
-WORKDIR ${COMMON_WORKDIR}
-
-
-# Install HIPBLASLt
FROM base AS build_hipblaslt
-ARG HIPBLASLT_BRANCH="e6da924"
-RUN git clone https://github.com/ROCm/hipBLASLt.git \
- && cd hipBLASLt \
+ARG HIPBLASLT_BRANCH
+ARG HIPBLAS_COMMON_BRANCH
+# Set to "--legacy_hipblas_direct" for ROCm<=6.2
+ARG LEGACY_HIPBLASLT_OPTION
+RUN git clone https://github.com/ROCm/hipBLAS-common.git
+RUN . .venv/bin/activate && cd hipBLAS-common \
+ && git checkout ${HIPBLAS_COMMON_BRANCH} \
+ && mkdir build \
+ && cd build \
+ && cmake .. \
+ && make package \
+ && dpkg -i ./*.deb
+RUN git clone https://github.com/ROCm/hipBLASLt
+RUN . .venv/bin/activate && cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
- && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
+ && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
&& cd build/release \
&& make package
+RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
-FROM scratch AS export_hipblaslt
-ARG COMMON_WORKDIR
-COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
-
-# RCCL build stages
FROM base AS build_rccl
-ARG RCCL_BRANCH="rocm-6.2.0"
-RUN git clone https://github.com/ROCm/rccl \
- && cd rccl \
+ARG RCCL_BRANCH
+ARG RCCL_REPO
+RUN git clone ${RCCL_REPO}
+RUN . .venv/bin/activate && cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
-FROM scratch AS export_rccl
-ARG COMMON_WORKDIR
-COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
+RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
-# Triton build stages
FROM base AS build_triton
-ARG TRITON_BRANCH="e192dba"
-ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
- && cd triton \
+ARG TRITON_BRANCH
+ARG TRITON_REPO
+RUN git clone ${TRITON_REPO}
+RUN . .venv/bin/activate && cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
-FROM scratch AS export_triton
-ARG COMMON_WORKDIR
-COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
+RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
-# # AMD-SMI build stages
FROM base AS build_amdsmi
-RUN cd /opt/rocm/share/amd_smi \
+RUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
-FROM scratch AS export_amdsmi
-COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
+RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
+FROM base AS build_pytorch
+ARG PYTORCH_BRANCH
+ARG PYTORCH_VISION_BRANCH
+ARG PYTORCH_REPO
+ARG PYTORCH_VISION_REPO
+ARG FA_BRANCH
+ARG FA_REPO
+RUN git clone ${PYTORCH_REPO} pytorch
+RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
+ pip install -r requirements.txt && git submodule update --init --recursive \
+ && python3 tools/amd_build/build_amd.py \
+ && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
+ && pip install dist/*.whl
+RUN git clone ${PYTORCH_VISION_REPO} vision
+RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
+ && python3 setup.py bdist_wheel --dist-dir=dist \
+ && pip install dist/*.whl
+RUN git clone ${FA_REPO}
+RUN . .venv/bin/activate && cd flash-attention \
+ && git checkout ${FA_BRANCH} \
+ && git submodule update --init \
+ && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
+RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
+ && cp /app/vision/dist/*.whl /app/install \
+ && cp /app/flash-attention/dist/*.whl /app/install
-FROM base as build_pytorch
+FROM base AS final
+RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
+ dpkg -i /install/*deb \
+ && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+ && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
+RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
+ dpkg -i /install/*deb \
+ && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+ && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
+RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
+ . .venv/bin/activate && \
+ pip install /install/*.whl
+RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
+ . .venv/bin/activate && \
+ pip install /install/*.whl
+RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
+ . .venv/bin/activate && \
+ pip install /install/*.whl
-RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
- if ls /install/*.deb; then \
- dpkg -i /install/*.deb \
- && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
- && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
- fi
+ARG AITER_REPO
+ARG AITER_BRANCH
+RUN git clone --recursive ${AITER_REPO}
+RUN . .venv/bin/activate && cd aiter \
+ && git checkout ${AITER_BRANCH} \
+ && git submodule update --init --recursive \
+ && pip install -r requirements.txt \
+ && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
-ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
-
-# A commit to fix the output scaling factor issue in _scaled_mm
-# Not yet in 2.5.0-rc1
-ARG PYTORCH_BRANCH="cedc116"
-ARG PYTORCH_VISION_BRANCH="v0.19.1"
-ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
-
-RUN git clone ${PYTORCH_REPO} pytorch \
- && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
- && pip install -r requirements.txt --no-cache-dir \
- && python tools/amd_build/build_amd.py \
- && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist
-FROM scratch as export_pytorch
-ARG COMMON_WORKDIR
-COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
-
-FROM base AS install_deps
-
-ARG COMMON_WORKDIR
-
-# Install hipblaslt
-RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
- if ls /install/*.deb; then \
- dpkg -i /install/*.deb \
- && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
- && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
- fi
-
-RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
- if ls /install/*.deb; then \
- dpkg -i /install/*.deb \
- # RCCL needs to be installed twice
- && dpkg -i /install/*.deb \
- && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
- && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
- fi
-
-RUN --mount=type=bind,from=export_triton,src=/,target=/install \
- if ls /install/*.whl; then \
- # Preemptively uninstall to prevent pip same-version no-installs
- pip uninstall -y triton \
- && pip install /install/*.whl; \
- fi
-
-RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
- # Preemptively uninstall to prevent pip same-version no-installs
- pip uninstall -y amdsmi \
- && pip install /install/*.whl;
-
-RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
- if ls /install/*.whl; then \
- # Preemptively uninstall to prevent pip same-version no-installs
- pip uninstall -y torch torchvision \
- && pip install /install/*.whl; \
- fi
-
-FROM install_deps AS kernel-builder
+RUN rm -rf /var/lib/apt/lists/*
+FROM final AS kernel-builder
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
-WORKDIR /usr/src
COPY server/Makefile-vllm Makefile
+RUN . .venv/bin/activate && pip install setuptools_scm
# Build specific version of vllm
-RUN make build-vllm-rocm
-
-# Build Flash Attention v2 kernels
-FROM kernel-builder AS flash-att-v2-builder
-WORKDIR /usr/src
-
-COPY server/Makefile-flash-att-v2 Makefile
-
-# Build specific version of flash attention v2
-RUN make build-flash-attention-v2-rocm
+RUN . .venv/bin/activate && make build-vllm-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
-WORKDIR /usr/src
COPY server/custom_kernels/ .
-RUN python setup.py build
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
-WORKDIR /usr/src
COPY server/exllama_kernels/ .
-
-RUN python setup.py build
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
-WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
+RUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist
-RUN python setup.py build
+FROM kernel-builder AS marlin-kernels
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \
+ cd marlin-kernels && \
+ git checkout ${MARLIN_KERNELS_BRANCH} && \
+ python3 setup.py bdist_wheel --dist-dir=dist
-FROM install_deps AS base-copy
+FROM kernel-builder AS moe-kernels
+ENV MOE_KERNELS_BRANCH=v0.8.2
+ENV VLLM_TARGET_DEVICE=rocm
+RUN . .venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \
+ cd moe-kernels && \
+ git checkout ${MOE_KERNELS_BRANCH} && \
+ python3 setup.py bdist_wheel --dist-dir=dist
+
+FROM final AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
-# Copy builds artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-
-# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+ENV VIRTUAL_ENV=/app/.venv/
+ENV PATH="$PATH:/app/.venv/bin/"
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
- make gen-server && \
- pip install -r requirements_rocm.txt && \
- pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+ uv pip install grpcio-tools mypy-protobuf && \
+ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
+ make gen-server-raw
+RUN cd server && \
+ pwd && \
+ text-generation-server --help
+
+RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
+ uv pip install /install/*.whl
+RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
+ uv pip install /install/*.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -304,7 +279,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
-ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image
FROM base AS sagemaker
@@ -335,4 +309,6 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
+ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
+# CMD ["--json-output"]
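A hedged build sketch for the reworked ROCm image; `PYTORCH_ROCM_ARCH` defaults to `gfx90a;gfx942` above, and the tag is a placeholder:

    # restrict the ROCm build to a single GPU architecture
    docker build -f Dockerfile_amd --build-arg PYTORCH_ROCM_ARCH="gfx942" -t tgi-rocm .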
diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
new file mode 100644
index 000000000..06073fe40
--- /dev/null
+++ b/Dockerfile_gaudi
@@ -0,0 +1,126 @@
+# These arguments are required to build the image
+ARG HABANA_VERSION=1.20.0
+ARG PYTORCH_VERSION=2.6.0
+
+# Rust builder
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
+WORKDIR /usr/src
+
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
+FROM chef AS planner
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+ENV PYO3_PYTHON="/root/.local/bin/python" \
+ PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \
+ PYO3_PYTHON_VERSION="3.10"
+
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
+ && . $HOME/.local/bin/env \
+ && uv python install 3.10 --default --preview \
+ && test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1)
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+ rm -f $PROTOC_ZIP
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY backends backends
+COPY launcher launcher
+RUN cargo build --profile release-opt
+
+# Text Generation Inference base image
+ARG HABANA_VERSION
+ARG PYTORCH_VERSION
+
+FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
+
+ENV ATTENTION=default
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
+
+# Text Generation Inference base env
+ENV HF_HOME=/data \
+ HF_HUB_ENABLE_HF_TRANSFER=1 \
+ PORT=80
+
+# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10
+RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1)
+
+# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, so install it manually
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+ dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+WORKDIR /usr/src
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+ libssl-dev \
+ ca-certificates \
+ make \
+ curl \
+ git \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install server
+COPY proto proto
+COPY backends/gaudi/server server
+COPY backends/gaudi/server/Makefile server/Makefile
+ARG HABANA_VERSION
+RUN cd server && \
+ make gen-server && \
+ pip install --no-deps -r requirements.txt && \
+ bash ./dill-0.3.8-patch.sh && \
+ pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
+ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
+ pip install . --no-cache-dir
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
+# Install benchmarker
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+
+# AWS Sagemaker compatible image
+FROM base AS sagemaker
+
+COPY sagemaker-entrypoint.sh entrypoint.sh
+RUN chmod +x entrypoint.sh
+
+ENTRYPOINT ["./entrypoint.sh"]
+
+# Final image
+FROM base
+
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE
+
+COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
+RUN chmod +x /tgi-entrypoint.sh
+
+ENTRYPOINT ["/tgi-entrypoint.sh"]
+CMD ["--json-output"]
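The Gaudi image pins its base through the two build arguments declared at the top of the file; a sketch of overriding them explicitly (the tag is a placeholder):

    docker build -f Dockerfile_gaudi \
        --build-arg HABANA_VERSION=1.20.0 \
        --build-arg PYTORCH_VERSION=2.6.0 \
        -t tgi-gaudi .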
diff --git a/Dockerfile_intel b/Dockerfile_intel
index e024f31a5..b2a905ec9 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -1,6 +1,6 @@
ARG PLATFORM=xpu
-FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
-FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
USER root
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-pti-dev-0.9
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -96,29 +96,28 @@ ENV HF_HOME=/data \
-WORKDIR /usr/src
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
+WORKDIR /usr/src
+RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
make gen-server && \
- pip install -r requirements_intel.txt && \
- pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+ pip install -U pip uv && \
+ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/lib
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets
-#ENV TORCH_LLM_ALLREDUCE=1
-#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_LLM_ALLREDUCE=1
+ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
@@ -158,7 +157,7 @@ ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@@ -181,22 +180,14 @@ RUN case ${TARGETPLATFORM} in \
RUN conda install -c conda-forge gperftools mkl
-
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
-
-RUN pip install triton py-libnuma
+RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
+RUN pip install triton==3.1.0 py-libnuma
WORKDIR /usr/src
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout b7b552baf64283b594665b8687430fe92990e497
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
-RUN sed -i 's/VERSION_MINOR 6/VERSION_MINOR 5/' intel-extension-for-pytorch/version.txt
-RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
-
-RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
@@ -209,10 +200,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
+ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
make gen-server && \
- pip install -r requirements_intel.txt && \
- pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
+ pip install -U pip uv && \
+ uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -222,9 +214,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
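`PLATFORM` selects which stage becomes `final` and defaults to `xpu`; a sketch of building both variants, assuming the conda-based stage in the second half of this file is named `cpu` (tags are placeholders):

    docker build -f Dockerfile_intel --build-arg PLATFORM=xpu -t tgi-intel-xpu .
    docker build -f Dockerfile_intel --build-arg PLATFORM=cpu -t tgi-intel-cpu .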
diff --git a/Dockerfile_llamacpp b/Dockerfile_llamacpp
new file mode 100644
index 000000000..291ae88cb
--- /dev/null
+++ b/Dockerfile_llamacpp
@@ -0,0 +1,88 @@
+FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
+
+ARG llamacpp_version=b4827
+ARG llamacpp_cuda=OFF
+ARG llamacpp_native=ON
+ARG llamacpp_cpu_arm_arch=native
+ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
+
+WORKDIR /opt/src
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt update && apt upgrade -y && apt install -y \
+ clang \
+ cmake \
+ curl \
+ git \
+ python3-dev \
+ libssl-dev \
+ pkg-config \
+ tar
+
+ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
+RUN mkdir -p llama.cpp \
+ && tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
+ && cd llama.cpp \
+ && cmake -B build \
+ -DCMAKE_INSTALL_PREFIX=/usr \
+ -DCMAKE_INSTALL_LIBDIR=/usr/lib \
+ -DCMAKE_C_COMPILER=clang \
+ -DCMAKE_CXX_COMPILER=clang++ \
+ -DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
+ -DGGML_CUDA=${llamacpp_cuda} \
+ -DGGML_NATIVE=${llamacpp_native} \
+ -DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
+ -DLLAMA_BUILD_COMMON=OFF \
+ -DLLAMA_BUILD_TESTS=OFF \
+ -DLLAMA_BUILD_EXAMPLES=OFF \
+ -DLLAMA_BUILD_SERVER=OFF \
+ && cmake --build build --parallel --config Release \
+ && cmake --install build
+
+WORKDIR /app
+COPY rust-toolchain.toml rust-toolchain.toml
+RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
+ENV PATH="/root/.cargo/bin:$PATH"
+RUN cargo install cargo-chef --locked
+
+FROM deps AS planner
+COPY . .
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM deps AS builder
+COPY --from=planner /app/recipe.json recipe.json
+RUN cargo chef cook \
+ --recipe-path recipe.json \
+ --profile release \
+ --package text-generation-router-llamacpp
+COPY . .
+RUN cargo build \
+ --profile release \
+ --package text-generation-router-llamacpp --frozen
+
+FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt update && apt upgrade -y && apt install -y \
+ python3-venv \
+ python3-pip
+
+RUN python3 -m venv /venv
+ENV PATH="/venv/bin:$PATH"
+
+COPY backends/llamacpp/requirements.txt requirements.txt
+COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
+COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
+
+RUN pip3 install --no-cache-dir \
+ -r requirements.txt \
+ -e gguf-py
+
+COPY --from=builder /usr/lib/libllama.so /usr/lib/
+COPY --from=builder /usr/lib/libggml*.so /usr/lib/
+COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
+
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+
+ENTRYPOINT ["text-generation-router-llamacpp"]
diff --git a/Dockerfile_trtllm b/Dockerfile_trtllm
index 3ccb0310b..7df2e3685 100644
--- a/Dockerfile_trtllm
+++ b/Dockerfile_trtllm
@@ -1,52 +1,56 @@
-ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
-ARG OMPI_VERSION="4.1.6"
+ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
+ARG cuda_base=12.8.0
+ARG build_type=release
+ARG ompi_version=4.1.7
+ARG sccache_gha_enabled=off
+ARG actions_cache_url=""
+ARG actions_runtime_token=""
-# Build dependencies resolver stage
-FROM lukemathwalker/cargo-chef:latest AS chef
-WORKDIR /usr/src/text-generation-inference/backends/trtllm
-
-FROM chef AS planner
-COPY . .
-RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
-FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
+FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
-RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
- --mount=type=cache,target=/var/lib/apt,sharing=locked \
- apt update && apt install -y \
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
cmake \
curl \
- gcc \
- g++ \
+ gcc-14 \
+ g++-14 \
git \
git-lfs \
+ lld \
libssl-dev \
+ libucx-dev \
+ libasan8 \
+ libubsan1 \
ninja-build \
pkg-config \
+ pipx \
python3 \
python3-dev \
python3-setuptools \
tar \
- wget
+ wget --no-install-recommends && \
+ pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
-ARG OMPI_VERSION
+WORKDIR /opt/src/mpi
-ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
-RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
- mkdir /usr/src/mpi && \
- tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
- cd /usr/src/mpi && \
+ARG ompi_version
+ENV OMPI_VERSION=${ompi_version}
+ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
+ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
+ https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
+
+RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
- rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
+ rm -rf ${OMPI_TARBALL_FILENAME}/..
# Install TensorRT
FROM cuda-builder AS trt-builder
@@ -58,38 +62,62 @@ RUN chmod +x /opt/install_tensorrt.sh && \
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
+# Scoped global args reuse
+ARG cuda_arch_list
+ARG build_type
+ARG sccache_gha_enabled
+ARG actions_cache_url
+ARG actions_runtime_token
+
# Install Rust
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
- chmod -R a+w /root/.rustup && \
- chmod -R a+w /root/.cargo
-
ENV PATH="/root/.cargo/bin:$PATH"
-RUN cargo install cargo-chef
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
+ chmod -R a+w /root/.rustup && \
+ chmod -R a+w /root/.cargo && \
+ cargo install sccache --locked
-# Cache dependencies
-COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
-RUN cargo chef cook --release --recipe-path recipe.json
-
-# Build actual TGI
-ARG CUDA_ARCH_LIST
-ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
-ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
+ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
+ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
-COPY . .
+ENV USE_LLD_LINKER=ON
+ENV CUDA_ARCH_LIST=${cuda_arch_list}
+
+# sccache-specific args - before finding a better, more generic, way...
+ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
+ENV ACTIONS_CACHE_URL=${actions_cache_url}
+ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
+
+COPY Cargo.lock Cargo.lock
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY router router
+COPY backends backends
+COPY benchmark benchmark
+COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
-RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
- cd backends/trtllm && \
- CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
-FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
-RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \
+ENV RUSTC_WRAPPER=sccache
+ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
+RUN export CC=gcc-14 \
+ export CXX=g++-14 \
+ export CMAKE_C_COMPILER_LAUNCHER=sccache && \
+ export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
+ export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
+ mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
+ cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
+ sccache --show-stats
+
+FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
+RUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
- python3 -m pip install transformers tokenizers
+ pipx ensurepath && \
+ pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
@@ -99,10 +127,33 @@ COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
+# This is used only for the CI/CD
+FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
+RUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
+ rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
+ pipx ensurepath && \
+ pipx install --include-deps transformers tokenizers
+
+WORKDIR /usr/local/tgi/bin
+
+ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
+ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+ENV TOKENIZERS_PARALLELISM=false
+ENV OMPI_MCA_plm_rsh_agent=""
+
+COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
+COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
+COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
+
+# Basically we copy from target/debug instead of target/release
+COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
+
+# This is the final image
FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
+LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"
ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
diff --git a/Makefile b/Makefile
index 3068a06f4..2ecdd45ca 100644
--- a/Makefile
+++ b/Makefile
@@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:
clean:
rm -rf target aml
+
+preview_doc:
+ doc-builder preview text-generation-inference docs/source --not_python_module
diff --git a/README.md b/README.md
index 6d3a9b124..ed7b4809c 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
-
+
# Text Generation Inference
@@ -84,7 +84,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
+ ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
```
And then you can make requests like
@@ -121,7 +121,7 @@ curl localhost:8080/v1/chat/completions \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.2.3-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@@ -141,8 +141,8 @@ You have the option to utilize the `HF_TOKEN` environment variable for configuri
For example, if you want to serve the gated Llama V2 model variants:
1. Go to https://huggingface.co/settings/tokens
-2. Copy your cli READ token
-3. Export `HF_TOKEN=<your cli READ token>`
+2. Copy your CLI READ token
+3. Export `HF_TOKEN=<your READ token>`
or with Docker:
@@ -151,13 +151,14 @@ model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your READ token>
-docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.0.0 --model-id $model
+docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:3.2.3 --model-id $model
```
### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
-`PyTorch` to do distributed training/inference. `text-generation-inference` make
+`PyTorch` to do distributed training/inference. `text-generation-inference` makes
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
@@ -196,7 +197,7 @@ Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with T
You can also opt to install `text-generation-inference` locally.
-First clone the repository and change directoy into it:
+First clone the repository and change directory into it:
```shell
git clone https://github.com/huggingface/text-generation-inference
@@ -213,7 +214,7 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
conda create -n text-generation-inference python=3.11
conda activate text-generation-inference
-#using pyton venv
+#using python venv
python3 -m venv .venv
source .venv/bin/activate
```
@@ -262,7 +263,8 @@ locally, which can take hours.
After that you can run TGI with `nix run`:
```shell
-nix run . -- --model-id meta-llama/Llama-3.1-8B-Instruct
+cd text-generation-inference
+nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
```
**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)
diff --git a/backends/gaudi/Makefile b/backends/gaudi/Makefile
new file mode 100644
index 000000000..f760f4d6e
--- /dev/null
+++ b/backends/gaudi/Makefile
@@ -0,0 +1,62 @@
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+root_dir := ${mkfile_dir}/../..
+
+HABANA_VERSION := 1.20.0
+PYTORCH_VERSION := 2.6.0
+
+.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
+
+image:
+ docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+
+run-local-dev-container:
+ docker run -it \
+ --runtime=habana \
+ --ipc=host \
+ --cap-add=sys_nice \
+ --net=host \
+ -e HABANA_VISIBLE_DEVICES=all \
+ -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
+ -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
+ -e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \
+ -e LOG_LEVEL=debug \
+ -e PORT=8080 \
+ -v /home/ubuntu/.cache/huggingface:/data \
+ -v $(PWD):/text-generation-inference \
+ -w /text-generation-inference \
+ vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest
+
+install-dependencies:
+ pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
+ pip install outlines~=0.0.34
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+
+install-server:
+ make -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3
+
+install-router:
+ make -C ${root_dir} install-router
+
+install-launcher:
+ make -C ${root_dir} install-launcher
+
+# use source to load rust into the path
+local-dev-install: install-dependencies
+ bash -c 'source "$$HOME/.cargo/env" && \
+ make install-server && \
+ make install-router && \
+ make install-launcher'
+
+# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
+run-integration-tests:
+ uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
+ DOCKER_VOLUME=${root_dir}/data \
+ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
+ uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests
+
+# This is used to capture the expected outputs for the integration tests, offering an easy way to add more models to the integration tests
+capture-expected-outputs-for-integration-tests:
+ DOCKER_VOLUME=${root_dir}/data \
+ HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
+ uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py
diff --git a/backends/gaudi/README.md b/backends/gaudi/README.md
new file mode 100644
index 000000000..ba890f0b1
--- /dev/null
+++ b/backends/gaudi/README.md
@@ -0,0 +1,142 @@
+# Text-generation-inference - Gaudi backend
+
+## Description
+
+This is the TGI backend for Intel Gaudi. It is composed of the TGI server optimized for Gaudi hardware.
+
+## Build your own image
+
+The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
+
+Option 1: From the project root directory:
+```bash
+make -C backends/gaudi image
+```
+
+Option 2: From the Gaudi backend directory:
+```bash
+cd backends/gaudi
+make image
+```
+
+You can now run the server with the following command:
+
+Option 1: Sharded:
+```bash
+model=meta-llama/Llama-3.1-8B-Instruct
+hf_token=$(cat ${HOME}/.cache/huggingface/token)
+volume=${HOME}/.cache/huggingface
+
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data \
+ -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
+ tgi-gaudi --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
+```
+
+Option 2: Non-sharded:
+```bash
+model=meta-llama/Llama-3.1-8B-Instruct
+hf_token=$(cat ${HOME}/.cache/huggingface/token)
+volume=${HOME}/.cache/huggingface
+
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data \
+ -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
+ tgi-gaudi --model-id $model \
+ --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
+```
+
+## Contributing
+
+### Local Development
+
+This is useful if you want to run the server locally for better debugging.
+```bash
+make -C backends/gaudi run-local-dev-container
+```
+
+Then run the following command inside the container to install TGI for Gaudi:
+```bash
+make -C backends/gaudi local-dev-install
+```
+
+Add Rust to the path:
+```bash
+. "$HOME/.cargo/env"
+```
+
+Option 1: Run the server (sharded model):
+```bash
+LOG_LEVEL=debug text-generation-launcher \
+ --model-id meta-llama/Llama-3.1-8B-Instruct \
+ --sharded true \
+ --num-shard 8 \
+ --max-input-tokens 512 \
+ --max-total-tokens 1024 \
+ --max-batch-size 8 \
+ --max-batch-prefill-tokens 2048
+```
+
+Option 2: Run the server (non-sharded model):
+```bash
+LOG_LEVEL=debug text-generation-launcher \
+ --model-id meta-llama/Llama-3.1-8B-Instruct \
+ --max-input-tokens 512 \
+ --max-total-tokens 1024 \
+ --max-batch-size 4 \
+ --max-batch-prefill-tokens 2048
+```
+
+You can then test the server with the following curl command from another terminal (can be outside the container):
+```bash
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+### Integration tests
+
+To run the integration tests, you need to first build the image:
+```bash
+make -C backends/gaudi image
+```
+
+Then run the following command to run the integration tests:
+```bash
+make -C backends/gaudi run-integration-tests
+```
+
+To capture the expected outputs for the integration tests, you can run the following command:
+```bash
+make -C backends/gaudi capture-expected-outputs-for-integration-tests
+```
+
+#### How the integration tests work
+The integration tests work as follows:
+
+1. Start a tgi server in a container, similar to the command:
+```bash
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data \
+ -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
+ tgi-gaudi --model-id $model \
+ --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
+```
+
+2. Do a /generate request to the server, similar to the command:
+```bash
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+ -H 'Content-Type: application/json'
+```
+
+3. Check the output of the server against the expected output:
+```python
+assert curl_output == expected_output
+```
+
+This is then repeated for a set of models and configurations.
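+
+To add a new model to this suite, an entry is added to `TEST_CONFIGS` in `test_model.py`. The sketch below is a hypothetical entry only, assuming the keys read by `conftest.py` and `capture_expected_outputs.py` (`model_id`, `input`, `args`, optional `env_config`, `expected_greedy_output`, `expected_batch_output`); leave the expected outputs as `"unknown"` and use the capture target above to fill them in.
+
+```python
+# Hypothetical TEST_CONFIGS entry (name and values are illustrative only)
+TEST_CONFIGS = {
+    "llama3.1-8b-example": {
+        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
+        "input": "What is Deep Learning?",
+        # Launcher arguments passed as the container command
+        "args": [
+            "--max-input-tokens", "512",
+            "--max-total-tokens", "1024",
+            "--max-batch-size", "4",
+            "--max-batch-prefill-tokens", "2048",
+        ],
+        # Optional extra environment variables for the container
+        "env_config": {"MAX_TOTAL_TOKENS": "1024"},
+        # Leave as "unknown" until captured by capture_expected_outputs.py
+        "expected_greedy_output": "unknown",
+        "expected_batch_output": "unknown",
+    },
+}
+```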
diff --git a/backends/gaudi/examples/docker_commands/docker_commands.md b/backends/gaudi/examples/docker_commands/docker_commands.md
new file mode 100644
index 000000000..597012892
--- /dev/null
+++ b/backends/gaudi/examples/docker_commands/docker_commands.md
@@ -0,0 +1,283 @@
+# Examples of Docker Commands for Gaudi Backend
+
+This page gives a list of examples of docker run commands for some of the most popular models.
+
+> **Note:** The parameters are chosen to maximize performance on Gaudi2 hardware; please adjust them based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
+
+## Default Precision (BF16)
+
+### Llama3.1-8B on 1 card (BF16)
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e PREFILL_BATCH_BUCKET_SIZE=2 \
+ -e BATCH_BUCKET_SIZE=32 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 2048 --max-batch-size 32 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
+```
+
+### Llama3.1-70B 8 cards (BF16)
+
+```bash
+model=meta-llama/Meta-Llama-3.1-70B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e BATCH_BUCKET_SIZE=256 \
+ -e PREFILL_BATCH_BUCKET_SIZE=4 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
+
+### Llama2-7B on 1 Card (BF16)
+
+```bash
+model=meta-llama/Llama-2-7b-chat-hf
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e PREFILL_BATCH_BUCKET_SIZE=2 \
+ -e BATCH_BUCKET_SIZE=32 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 2048 --max-batch-size 32 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
+```
+
+### Llama2-70B on 8 cards (BF16)
+
+```bash
+model=meta-llama/Llama-2-70b-chat-hf
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e BATCH_BUCKET_SIZE=256 \
+ -e PREFILL_BATCH_BUCKET_SIZE=4 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
+
+### Llava-v1.6-Mistral-7B on 1 card (BF16)
+
+```bash
+model=llava-hf/llava-v1.6-mistral-7b-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e PREFILL_BATCH_BUCKET_SIZE=1 \
+ -e BATCH_BUCKET_SIZE=1 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
+ --max-total-tokens 8192 --max-batch-size 4
+```
+
+## FP8 Precision
+
+Please refer to the [FP8 Precision](https://huggingface.co/docs/text-generation-inference/backends/gaudi_new#how-to-use-different-precision-formats) section for more details. You need to measure the statistics of the model first before running the model in FP8 precision.
+
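+The sketch below is only a hypothetical measurement run, not an official command: it assumes a `maxabs_measure.json` config is available under the mounted `quantization_config` directory and that calibration statistics are written to the mounted `hqt_output` directory; see the FP8 documentation linked above for the exact workflow. After measurement, the commands below switch `QUANT_CONFIG` to `maxabs_quant.json`.
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+   --runtime=habana \
+   --cap-add=sys_nice \
+   --ipc=host \
+   -v $volume:/data \
+   -v $PWD/quantization_config:/usr/src/quantization_config \
+   -v $PWD/hqt_output:/usr/src/hqt_output \
+   -e QUANT_CONFIG=./quantization_config/maxabs_measure.json \
+   -e HF_TOKEN=$hf_token \
+   ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+   --model-id $model \
+   --max-input-tokens 1024 --max-total-tokens 2048
+```
+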
+## Llama3.1-8B on 1 Card (FP8)
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e PREFILL_BATCH_BUCKET_SIZE=2 \
+ -e BATCH_BUCKET_SIZE=32 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 2048 --max-batch-size 32 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
+```
+
+## Llama3.1-70B on 8 cards (FP8)
+
+```bash
+model=meta-llama/Meta-Llama-3.1-70B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e BATCH_BUCKET_SIZE=256 \
+ -e PREFILL_BATCH_BUCKET_SIZE=4 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
+
+## Llama2-7B on 1 Card (FP8)
+
+```bash
+model=meta-llama/Llama-2-7b-chat-hf
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e PREFILL_BATCH_BUCKET_SIZE=2 \
+ -e BATCH_BUCKET_SIZE=32 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=256 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 2048 --max-batch-size 32 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
+```
+
+## Llama2-70B on 8 Cards (FP8)
+
+```bash
+model=meta-llama/Llama-2-70b-chat-hf
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e BATCH_BUCKET_SIZE=256 \
+ -e PREFILL_BATCH_BUCKET_SIZE=4 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
+
+## Llava-v1.6-Mistral-7B on 1 Card (FP8)
+
+```bash
+model=llava-hf/llava-v1.6-mistral-7b-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e PREFILL_BATCH_BUCKET_SIZE=1 \
+ -e BATCH_BUCKET_SIZE=1 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
+ --max-total-tokens 8192 --max-batch-size 4
+```
+
+## Llava-v1.6-Mistral-7B on 8 Cards (FP8)
+
+```bash
+model=llava-hf/llava-v1.6-mistral-7b-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e PREFILL_BATCH_BUCKET_SIZE=1 \
+ -e BATCH_BUCKET_SIZE=1 \
+ ghcr.io/huggingface/text-generation-inference:3.1.1-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
+ --max-total-tokens 8192 --max-batch-size 4
+```
diff --git a/backends/gaudi/server/.gitignore b/backends/gaudi/server/.gitignore
new file mode 100644
index 000000000..576746eec
--- /dev/null
+++ b/backends/gaudi/server/.gitignore
@@ -0,0 +1,164 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+transformers
+safetensors
+flash-attention/
+flash-attention-v2/
+vllm/
+llm-awq/
+eetq/
+mamba/
diff --git a/backends/gaudi/server/Makefile b/backends/gaudi/server/Makefile
new file mode 100644
index 000000000..b5b843387
--- /dev/null
+++ b/backends/gaudi/server/Makefile
@@ -0,0 +1,38 @@
+include Makefile-flash-att
+include Makefile-flash-att-v2
+include Makefile-vllm
+include Makefile-awq
+include Makefile-eetq
+include Makefile-selective-scan
+
+PROTO_PATH ?= ../proto/v3
+
+unit-tests:
+ pytest -s -vv -m "not private" tests
+
+gen-server:
+ # Compile protos
+ pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
+ mkdir text_generation_server/pb || true
+ python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
+ --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
+ find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+ touch text_generation_server/pb/__init__.py
+
+install: gen-server
+ pip install pip --upgrade
+ pip install --no-deps -r requirements.txt
+ pip install -e "."
+
+run-dev:
+ SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
+
+install-poetry:
+ curl -sSL https://install.python-poetry.org | python3 -
+
+update-lock:
+ rm poetry.lock
+ poetry lock --no-update
+
+export-requirements:
+ poetry export -o requirements.txt --without-hashes
diff --git a/backends/gaudi/server/Makefile-awq b/backends/gaudi/server/Makefile-awq
new file mode 100644
index 000000000..4e074a133
--- /dev/null
+++ b/backends/gaudi/server/Makefile-awq
@@ -0,0 +1,15 @@
+# Fork that adds only the correct stream to this kernel in order
+# to make cuda graphs work.
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
+
+awq:
+ rm -rf llm-awq
+ git clone https://github.com/huggingface/llm-awq
+
+build-awq: awq
+ cd llm-awq/ && git fetch && git checkout $(awq_commit)
+ cd llm-awq/awq/kernels && python setup.py build
+
+install-awq: build-awq
+ pip uninstall awq_inference_engine -y || true
+ cd llm-awq/awq/kernels && python setup.py install
diff --git a/backends/gaudi/server/Makefile-eetq b/backends/gaudi/server/Makefile-eetq
new file mode 100644
index 000000000..726e47b57
--- /dev/null
+++ b/backends/gaudi/server/Makefile-eetq
@@ -0,0 +1,13 @@
+eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
+
+eetq:
+ # Clone eetq
+ pip install packaging
+ git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+ cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
+ cd eetq && python setup.py build
+
+install-eetq: build-eetq
+ cd eetq && python setup.py install
diff --git a/backends/gaudi/server/Makefile-fbgemm b/backends/gaudi/server/Makefile-fbgemm
new file mode 100644
index 000000000..3b8061a1f
--- /dev/null
+++ b/backends/gaudi/server/Makefile-fbgemm
@@ -0,0 +1,15 @@
+fbgemm_commit := v0.8.0
+
+build-fbgemm:
+ @if [ ! -d "fbgemm" ]; then \
+ git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
+ fi
+ cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
+ git submodule update --init --recursive && \
+ cd fbgemm_gpu && \
+ pip install -r requirements.txt && \
+ CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
+
+install-fbgemm: build-fbgemm
+ cd fbgemm/fbgemm_gpu && \
+ CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
diff --git a/backends/gaudi/server/Makefile-flash-att b/backends/gaudi/server/Makefile-flash-att
new file mode 100644
index 000000000..29e75bc48
--- /dev/null
+++ b/backends/gaudi/server/Makefile-flash-att
@@ -0,0 +1,12 @@
+flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
+
+build-flash-attention:
+ if [ ! -d 'flash-attention' ]; then \
+ pip install -U packaging ninja --no-cache-dir && \
+ git clone https://github.com/HazyResearch/flash-attention.git; \
+ fi
+ cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
+ MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build
+
+install-flash-attention: build-flash-attention
+ cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
diff --git a/backends/gaudi/server/Makefile-flash-att-v2 b/backends/gaudi/server/Makefile-flash-att-v2
new file mode 100644
index 000000000..a9cdf7822
--- /dev/null
+++ b/backends/gaudi/server/Makefile-flash-att-v2
@@ -0,0 +1,21 @@
+flash_att_v2_commit_cuda := v2.6.1
+flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+
+build-flash-attention-v2-cuda:
+ pip install -U packaging wheel
+ pip install flash-attn==$(flash_att_v2_commit_cuda)
+
+install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
+ echo "Flash v2 installed"
+
+build-flash-attention-v2-rocm:
+ if [ ! -d 'flash-attention-v2' ]; then \
+ pip install -U packaging ninja --no-cache-dir && \
+ git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \
+ cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
+ git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
+ fi
+
+install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
+ cd flash-attention-v2 && \
+ GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
diff --git a/backends/gaudi/server/Makefile-selective-scan b/backends/gaudi/server/Makefile-selective-scan
new file mode 100644
index 000000000..b93b517d6
--- /dev/null
+++ b/backends/gaudi/server/Makefile-selective-scan
@@ -0,0 +1,28 @@
+selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
+
+causal-conv1d:
+ rm -rf causal-conv1d
+ git clone https://github.com/Dao-AILab/causal-conv1d.git
+
+build-causal-conv1d: causal-conv1d
+ cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
+ cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
+
+install-causal-conv1d: build-causal-conv1d
+ pip uninstall causal-conv1d -y || true
+ cd causal-conv1d/ && pip install .
+
+# selective-scan depends on causal-conv1d
+selective-scan:
+ rm -rf mamba
+ git clone https://github.com/state-spaces/mamba.git mamba
+
+build-selective-scan: selective-scan
+ cd mamba/ && git fetch && git checkout $(selective_scan_commit)
+ cd mamba && python setup.py build
+
+install-selective-scan: install-causal-conv1d build-selective-scan
+ pip uninstall selective-scan-cuda -y || true
+ cd mamba && pip install .
+
+build-all: build-causal-conv1d build-selective-scan
diff --git a/backends/gaudi/server/Makefile-vllm b/backends/gaudi/server/Makefile-vllm
new file mode 100644
index 000000000..18dcc4a0c
--- /dev/null
+++ b/backends/gaudi/server/Makefile-vllm
@@ -0,0 +1,23 @@
+commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
+commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
+build-vllm-cuda:
+ if [ ! -d 'vllm' ]; then \
+ pip install -U ninja packaging --no-cache-dir && \
+ git clone https://github.com/Narsil/vllm.git vllm; \
+ fi
+ cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
+
+install-vllm-cuda: build-vllm-cuda
+ cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
+
+build-vllm-rocm:
+ if [ ! -d 'vllm' ]; then \
+ pip install -U ninja packaging --no-cache-dir && \
+ git clone https://github.com/mht-sharma/vllm.git vllm; \
+ fi
+ cd vllm && git fetch && git checkout $(commit_rocm) && \
+ PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
+ cd vllm && git fetch && git checkout $(commit_rocm) && \
+ PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
diff --git a/backends/gaudi/server/README.md b/backends/gaudi/server/README.md
new file mode 100644
index 000000000..b8208f9ea
--- /dev/null
+++ b/backends/gaudi/server/README.md
@@ -0,0 +1,15 @@
+# Text Generation Inference Python gRPC Server
+
+A Python gRPC server for Text Generation Inference
+
+## Install
+
+```shell
+make install
+```
+
+## Run
+
+```shell
+make run-dev
+```
diff --git a/backends/gaudi/server/dill-0.3.7-patch.sh b/backends/gaudi/server/dill-0.3.7-patch.sh
new file mode 100644
index 000000000..5efd6c54b
--- /dev/null
+++ b/backends/gaudi/server/dill-0.3.7-patch.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
+pushd dill
+cat <<EOF > dill-0.3.7.patch
+diff --git a/dill/_dill.py b/dill/_dill.py
+index d0cf543..f6eb662 100644
+--- a/dill/_dill.py
++++ b/dill/_dill.py
+@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
+ XRangeType = range
+ from types import MappingProxyType as DictProxyType, new_class
+ from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
+-import __main__ as _main_module
++class _LazyMainModule(object):
++ _module = None
++ @property
++ def module(self):
++ if self._module is None:
++ import __main__ as _m_module
++ self._module = _m_module
++ return self._module
++_main_module = _LazyMainModule()
+ import marshal
+ import gc
+ # import zlib
+@@ -353,7 +361,7 @@ class Pickler(StockPickler):
+ _fmode = kwds.pop('fmode', None)
+ _recurse = kwds.pop('recurse', None)
+ StockPickler.__init__(self, file, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._diff_cache = {}
+ self._byref = settings['byref'] if _byref is None else _byref
+ self._strictio = False #_strictio
+@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
+ settings = Pickler.settings
+ _ignore = kwds.pop('ignore', None)
+ StockUnpickler.__init__(self, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._ignore = settings['ignore'] if _ignore is None else _ignore
+
+ def load(self): #NOTE: if settings change, need to update attributes
+ obj = StockUnpickler.load(self)
+- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
+ if not self._ignore:
+ # point obj class to main
+ try: obj.__class__ = getattr(self._main, type(obj).__name__)
+@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
+ logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
+ logger.trace(pickler, "# D1")
+- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
+ logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
+ logger.trace(pickler, "# D3")
+- elif '__name__' in obj and obj != _main_module.__dict__ \\
++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
+ and type(obj['__name__']) is str \\
+ and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
+ logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
+diff --git a/dill/session.py b/dill/session.py
+index 74234ab..1be8d89 100644
+--- a/dill/session.py
++++ b/dill/session.py
+@@ -233,7 +233,7 @@ def dump_module(
+ protocol = settings['protocol']
+ main = module
+ if main is None:
+- main = _main_module
++ main = _main_module.module
+ elif isinstance(main, str):
+ main = _import_module(main)
+ if not isinstance(main, ModuleType):
+@@ -501,7 +501,7 @@ def load_module(
+ pass
+ assert loaded is main
+ _restore_modules(unpickler, main)
+- if main is _main_module or main is module:
++ if main is _main_module.module or main is module:
+ return None
+ else:
+ return main
+
+EOF
+git apply dill-0.3.7.patch
+python -m pip install .
+popd
+rm -fr dill
diff --git a/backends/gaudi/server/dill-0.3.8-patch.sh b/backends/gaudi/server/dill-0.3.8-patch.sh
new file mode 100644
index 000000000..414790e7b
--- /dev/null
+++ b/backends/gaudi/server/dill-0.3.8-patch.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
+pushd dill
+cat <<EOF > dill-0.3.8.patch
+diff --git a/dill/_dill.py b/dill/_dill.py
+index d42432f..1d251e6 100644
+--- a/dill/_dill.py
++++ b/dill/_dill.py
+@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
+ XRangeType = range
+ from types import MappingProxyType as DictProxyType, new_class
+ from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
+-import __main__ as _main_module
++class _LazyMainModule(object):
++ _module = None
++ @property
++ def module(self):
++ if self._module is None:
++ import __main__ as _m_module
++ self._module = _m_module
++ return self._module
++_main_module = _LazyMainModule()
+ import marshal
+ import gc
+ # import zlib
+@@ -355,7 +363,7 @@ class Pickler(StockPickler):
+ _fmode = kwds.pop('fmode', None)
+ _recurse = kwds.pop('recurse', None)
+ StockPickler.__init__(self, file, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._diff_cache = {}
+ self._byref = settings['byref'] if _byref is None else _byref
+ self._strictio = False #_strictio
+@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
+ settings = Pickler.settings
+ _ignore = kwds.pop('ignore', None)
+ StockUnpickler.__init__(self, *args, **kwds)
+- self._main = _main_module
++ self._main = _main_module.module
+ self._ignore = settings['ignore'] if _ignore is None else _ignore
+
+ def load(self): #NOTE: if settings change, need to update attributes
+ obj = StockUnpickler.load(self)
+- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
+ if not self._ignore:
+ # point obj class to main
+ try: obj.__class__ = getattr(self._main, type(obj).__name__)
+@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
+ logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
+ logger.trace(pickler, "# D1")
+- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
+ logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
+ pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
+ logger.trace(pickler, "# D3")
+- elif '__name__' in obj and obj != _main_module.__dict__ \\
++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
+ and type(obj['__name__']) is str \\
+ and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
+ logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
+diff --git a/dill/session.py b/dill/session.py
+index e91068a..a921b43 100644
+--- a/dill/session.py
++++ b/dill/session.py
+@@ -233,7 +233,7 @@ def dump_module(
+ protocol = settings['protocol']
+ main = module
+ if main is None:
+- main = _main_module
++ main = _main_module.module
+ elif isinstance(main, str):
+ main = _import_module(main)
+ if not isinstance(main, ModuleType):
+@@ -501,7 +501,7 @@ def load_module(
+ pass
+ assert loaded is main
+ _restore_modules(unpickler, main)
+- if main is _main_module or main is module:
++ if main is _main_module.module or main is module:
+ return None
+ else:
+ return main
+
+EOF
+git apply dill-0.3.8.patch
+python -m pip install .
+popd
+rm -fr dill
diff --git a/backends/gaudi/server/integration-tests/capture_expected_outputs.py b/backends/gaudi/server/integration-tests/capture_expected_outputs.py
new file mode 100644
index 000000000..051b9d698
--- /dev/null
+++ b/backends/gaudi/server/integration-tests/capture_expected_outputs.py
@@ -0,0 +1,85 @@
+import json
+import os
+from typing import Dict, Any, Generator
+
+import pytest
+from test_model import TEST_CONFIGS
+
+UNKNOWN_CONFIGS = {
+ name: config
+ for name, config in TEST_CONFIGS.items()
+ if config["expected_greedy_output"] == "unknown"
+ or config["expected_batch_output"] == "unknown"
+}
+
+
+@pytest.fixture(scope="module", params=UNKNOWN_CONFIGS.keys())
+def test_config(request) -> Dict[str, Any]:
+ """Fixture that provides model configurations for testing."""
+ test_config = UNKNOWN_CONFIGS[request.param]
+ test_config["test_name"] = request.param
+ return test_config
+
+
+@pytest.fixture(scope="module")
+def test_name(test_config):
+ yield test_config["test_name"]
+
+
+@pytest.fixture(scope="module")
+def tgi_service(launcher, test_config, test_name) -> Generator:
+ """Fixture that provides a TGI service for testing."""
+ with launcher(test_config["model_id"], test_name) as service:
+ yield service
+
+
+@pytest.mark.asyncio
+async def test_capture_expected_outputs(tgi_service, test_config, test_name):
+ """Test that captures expected outputs for models with unknown outputs."""
+ print(f"Testing {test_name} with {test_config['model_id']}")
+
+ # Wait for service to be ready
+ await tgi_service.health(1000)
+ client = tgi_service.client
+
+ # Test single request (greedy)
+ print("Testing single request...")
+ response = await client.generate(
+ test_config["input"],
+ max_new_tokens=32,
+ )
+ greedy_output = response.generated_text
+
+ # Test multiple requests (batch)
+ print("Testing batch requests...")
+ responses = []
+ for _ in range(4):
+ response = await client.generate(
+ test_config["input"],
+ max_new_tokens=32,
+ )
+ responses.append(response.generated_text)
+
+ # Store results in a JSON file
+ output_file = "server/integration-tests/expected_outputs.json"
+ results = {}
+
+ # Try to load existing results if file exists
+ if os.path.exists(output_file):
+ with open(output_file, "r") as f:
+ results = json.load(f)
+
+ # Update results for this model
+ results[test_name] = {
+ "model_id": test_config["model_id"],
+ "input": test_config["input"],
+ "greedy_output": greedy_output,
+ "batch_outputs": responses,
+ "args": test_config["args"],
+ }
+
+ # Save updated results
+ with open(output_file, "w") as f:
+ json.dump(results, f, indent=2)
+
+ print(f"\nResults for {test_name} saved to {output_file}")
diff --git a/backends/gaudi/server/integration-tests/conftest.py b/backends/gaudi/server/integration-tests/conftest.py
new file mode 100644
index 000000000..c7daf70e0
--- /dev/null
+++ b/backends/gaudi/server/integration-tests/conftest.py
@@ -0,0 +1,292 @@
+import asyncio
+import contextlib
+import os
+import shlex
+import subprocess
+import sys
+import threading
+import time
+from tempfile import TemporaryDirectory
+from typing import List
+import socket
+
+import docker
+import pytest
+from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+from docker.errors import NotFound
+from loguru import logger
+from test_model import TEST_CONFIGS
+from text_generation import AsyncClient
+from text_generation.types import Response
+
+# Use the latest image from the local docker build
+DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
+DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
+HF_TOKEN = os.getenv("HF_TOKEN", None)
+
+assert (
+ HF_TOKEN is not None
+), "HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it"
+
+if DOCKER_VOLUME is None:
+ logger.warning(
+ "DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
+ )
+
+LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
+
+BASE_ENV = {
+ "HF_HUB_ENABLE_HF_TRANSFER": "1",
+ "LOG_LEVEL": LOG_LEVEL,
+ "HF_TOKEN": os.getenv("HF_TOKEN", None),
+}
+
+
+HABANA_RUN_ARGS = {
+ "runtime": "habana",
+ "ipc_mode": "host",
+ "cap_add": ["sys_nice"],
+}
+
+logger.add(
+ sys.stderr,
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
+ level="INFO",
+)
+
+
+def stream_container_logs(container, test_name):
+ """Stream container logs in a separate thread."""
+ try:
+ for log in container.logs(stream=True, follow=True):
+ print(
+ f"[TGI Server Logs - {test_name}] {log.decode('utf-8')}",
+ end="",
+ file=sys.stderr,
+ flush=True,
+ )
+ except Exception as e:
+ logger.error(f"Error streaming container logs: {str(e)}")
+
+
+class LauncherHandle:
+ def __init__(self, port: int):
+ self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
+
+ def _inner_health(self):
+ raise NotImplementedError
+
+ async def health(self, timeout: int = 60):
+ assert timeout > 0
+ start_time = time.time()
+ logger.info(f"Starting health check with timeout of {timeout}s")
+
+ for attempt in range(timeout):
+ if not self._inner_health():
+ logger.error("Launcher crashed during health check")
+ raise RuntimeError("Launcher crashed")
+
+ try:
+ await self.client.generate("test")
+ elapsed = time.time() - start_time
+ logger.info(f"Health check passed after {elapsed:.1f}s")
+ return
+ except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
+ if attempt == timeout - 1:
+ logger.error(f"Health check failed after {timeout}s: {str(e)}")
+ raise RuntimeError(f"Health check failed: {str(e)}")
+ if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
+ logger.debug(
+ f"Connection attempt {attempt}/{timeout} failed: {str(e)}"
+ )
+ time.sleep(1)
+ except Exception as e:
+ logger.error(f"Unexpected error during health check: {str(e)}")
+ # Get full traceback for debugging
+ import traceback
+
+ logger.error(f"Full traceback:\n{traceback.format_exc()}")
+ raise
+
+
+class ContainerLauncherHandle(LauncherHandle):
+ def __init__(self, docker_client, container_name, port: int):
+ super(ContainerLauncherHandle, self).__init__(port)
+ self.docker_client = docker_client
+ self.container_name = container_name
+
+ def _inner_health(self) -> bool:
+ try:
+ container = self.docker_client.containers.get(self.container_name)
+ status = container.status
+ if status not in ["running", "created"]:
+ logger.warning(f"Container status is {status}")
+ # Get container logs for debugging
+ logs = container.logs().decode("utf-8")
+ logger.debug(f"Container logs:\n{logs}")
+ return status in ["running", "created"]
+ except Exception as e:
+ logger.error(f"Error checking container health: {str(e)}")
+ return False
+
+
+class ProcessLauncherHandle(LauncherHandle):
+ def __init__(self, process, port: int):
+ super(ProcessLauncherHandle, self).__init__(port)
+ self.process = process
+
+ def _inner_health(self) -> bool:
+ return self.process.poll() is None
+
+
+@pytest.fixture(scope="module")
+def data_volume():
+ tmpdir = TemporaryDirectory()
+ yield tmpdir.name
+ try:
+ # Cleanup the temporary directory using sudo as it contains root files created by the container
+ subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
+ except subprocess.CalledProcessError as e:
+ logger.error(f"Error cleaning up temporary directory: {str(e)}")
+
+
+@pytest.fixture(scope="module")
+def launcher(data_volume):
+ @contextlib.contextmanager
+ def docker_launcher(
+ model_id: str,
+ test_name: str,
+ ):
+ logger.info(
+ f"Starting docker launcher for model {model_id} and test {test_name}"
+ )
+
+ # Get a random available port
+ def get_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ s.listen(1)
+ port = s.getsockname()[1]
+ return port
+
+ port = get_free_port()
+ logger.debug(f"Using port {port}")
+
+ client = docker.from_env()
+
+ container_name = f"tgi-gaudi-test-{test_name.replace('/', '-')}"
+
+ try:
+ container = client.containers.get(container_name)
+ logger.info(
+ f"Stopping existing container {container_name} for test {test_name}"
+ )
+ container.stop()
+ container.wait()
+ except NotFound:
+ pass
+ except Exception as e:
+ logger.error(f"Error handling existing container: {str(e)}")
+
+ # Look up the per-test config by its key (test_name) rather than by model_id,
+ # since several entries share the same model_id (e.g. the sharded and
+ # non-sharded Llama-3.1-8B-Instruct configs).
+ tgi_args = TEST_CONFIGS[test_name]["args"].copy()
+
+ env = BASE_ENV.copy()
+
+ # Add model_id to env
+ env["MODEL_ID"] = model_id
+
+ # Add the env config that is defined in the fixture parameter
+ if "env_config" in TEST_CONFIGS[test_name]:
+ env.update(TEST_CONFIGS[test_name]["env_config"].copy())
+
+ volumes = [f"{DOCKER_VOLUME}:/data"]
+ logger.debug(f"Using volume {volumes}")
+
+ try:
+ logger.info(f"Creating container with name {container_name}")
+
+ # Create and start the container (detached); its logs are streamed in a background thread below
+ container = client.containers.run(
+ DOCKER_IMAGE,
+ command=tgi_args,
+ name=container_name,
+ environment=env,
+ detach=True,
+ volumes=volumes,
+ ports={"80/tcp": port},
+ **HABANA_RUN_ARGS,
+ )
+
+ logger.info(f"Container {container_name} started successfully")
+
+ # Start log streaming in a background thread
+ log_thread = threading.Thread(
+ target=stream_container_logs,
+ args=(container, test_name),
+ daemon=True, # This ensures the thread will be killed when the main program exits
+ )
+ log_thread.start()
+
+ # Add a small delay to allow container to initialize
+ time.sleep(2)
+
+ # Check container status after creation
+ status = container.status
+ logger.debug(f"Initial container status: {status}")
+ if status not in ["running", "created"]:
+ logs = container.logs().decode("utf-8")
+ logger.error(f"Container failed to start properly. Logs:\n{logs}")
+
+ yield ContainerLauncherHandle(client, container.name, port)
+
+ except Exception as e:
+ logger.error(f"Error starting container: {str(e)}")
+ # Get full traceback for debugging
+ import traceback
+
+ logger.error(f"Full traceback:\n{traceback.format_exc()}")
+ raise
+ finally:
+ try:
+ container = client.containers.get(container_name)
+ logger.info(f"Stopping container {container_name}")
+ container.stop()
+ container.wait()
+
+ container_output = container.logs().decode("utf-8")
+ print(container_output, file=sys.stderr)
+
+ container.remove()
+ logger.info(f"Container {container_name} removed successfully")
+ except NotFound:
+ pass
+ except Exception as e:
+ logger.warning(f"Error cleaning up container: {str(e)}")
+
+ return docker_launcher
+
+
+@pytest.fixture(scope="module")
+def generate_load():
+ async def generate_load_inner(
+ client: AsyncClient, prompt: str, max_new_tokens: int, n: int
+ ) -> List[Response]:
+ try:
+ futures = [
+ client.generate(
+ prompt,
+ max_new_tokens=max_new_tokens,
+ decoder_input_details=True,
+ )
+ for _ in range(n)
+ ]
+ return await asyncio.gather(*futures)
+ except Exception as e:
+ logger.error(f"Error generating load: {str(e)}")
+ raise
+
+ return generate_load_inner
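
For orientation while reading this diff: a minimal sketch (illustrative only, not part of the change) of how the module-scoped launcher and generate_load fixtures above are meant to compose. test_model.py below is the real consumer; the sketch assumes DOCKER_IMAGE and HF_TOKEN are set in the environment and uses the "openai-community/gpt2" entry that appears in TEST_CONFIGS.

import pytest

# Hypothetical smoke test using the fixtures defined above.
# "openai-community/gpt2" is both the TEST_CONFIGS key and the model_id.
@pytest.mark.asyncio
async def test_smoke(launcher, generate_load):
    with launcher("openai-community/gpt2", "openai-community/gpt2") as handle:
        # Poll the generate endpoint until the server answers or the timeout expires
        await handle.health(timeout=600)
        responses = await generate_load(
            handle.client, "What is Deep Learning?", max_new_tokens=8, n=2
        )
        assert len(responses) == 2
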
diff --git a/backends/gaudi/server/integration-tests/pytest.ini b/backends/gaudi/server/integration-tests/pytest.ini
new file mode 100644
index 000000000..2f4c80e30
--- /dev/null
+++ b/backends/gaudi/server/integration-tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
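
A short note on this setting: with asyncio_mode = auto, pytest-asyncio treats plain async def test functions as asyncio tests, so the explicit @pytest.mark.asyncio markers used in test_model.py below are optional. A minimal illustration, assuming the tgi_client fixture defined in test_model.py:

# No @pytest.mark.asyncio marker needed under asyncio_mode = auto.
async def test_single_token(tgi_client):
    response = await tgi_client.generate("Hello", max_new_tokens=1)
    assert response.details.generated_tokens == 1
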
diff --git a/backends/gaudi/server/integration-tests/requirements.txt b/backends/gaudi/server/integration-tests/requirements.txt
new file mode 100644
index 000000000..b67d2d8cc
--- /dev/null
+++ b/backends/gaudi/server/integration-tests/requirements.txt
@@ -0,0 +1,7 @@
+pytest >= 8.3.5
+pytest-asyncio >= 0.26.0
+docker >= 7.1.0
+Levenshtein >= 0.27.1
+loguru >= 0.7.3
+aiohttp >= 3.11.14
+text-generation
diff --git a/backends/gaudi/server/integration-tests/test_model.py b/backends/gaudi/server/integration-tests/test_model.py
new file mode 100644
index 000000000..cb2bf6a9f
--- /dev/null
+++ b/backends/gaudi/server/integration-tests/test_model.py
@@ -0,0 +1,276 @@
+from typing import Any, Dict
+
+from text_generation import AsyncClient
+import pytest
+from Levenshtein import distance as levenshtein_distance
+
+# The "args" config is not optimized for speed but only check that the inference is working for the different models architectures
+TEST_CONFIGS = {
+ "meta-llama/Llama-3.1-8B-Instruct-shared": {
+ "model_id": "meta-llama/Llama-3.1-8B-Instruct",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
+ "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
+ "args": [
+ "--sharded",
+ "true",
+ "--num-shard",
+ "8",
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "8",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "meta-llama/Llama-3.1-8B-Instruct": {
+ "model_id": "meta-llama/Llama-3.1-8B-Instruct",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
+ "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
+ "env_config": {},
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "meta-llama/Llama-2-7b-chat-hf": {
+ "model_id": "meta-llama/Llama-2-7b-chat-hf",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
+ "expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "mistralai/Mistral-7B-Instruct-v0.3": {
+ "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
+ "expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "bigcode/starcoder2-3b": {
+ "model_id": "bigcode/starcoder2-3b",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
+ "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "google/gemma-7b-it": {
+ "model_id": "google/gemma-7b-it",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
+ "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "Qwen/Qwen2-0.5B-Instruct": {
+ "model_id": "Qwen/Qwen2-0.5B-Instruct",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
+ "expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ "--max-batch-prefill-tokens",
+ "2048",
+ ],
+ },
+ "tiiuae/falcon-7b-instruct": {
+ "model_id": "tiiuae/falcon-7b-instruct",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
+ "expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ ],
+ },
+ "microsoft/phi-1_5": {
+ "model_id": "microsoft/phi-1_5",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
+ "expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ ],
+ },
+ "openai-community/gpt2": {
+ "model_id": "openai-community/gpt2",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
+ "expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ ],
+ },
+ "facebook/opt-125m": {
+ "model_id": "facebook/opt-125m",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
+ "expected_batch_output": "\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout the Author\n\nAbout",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ ],
+ },
+ "EleutherAI/gpt-j-6b": {
+ "model_id": "EleutherAI/gpt-j-6b",
+ "input": "What is Deep Learning?",
+ "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
+ "expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
+ "args": [
+ "--max-input-tokens",
+ "512",
+ "--max-total-tokens",
+ "1024",
+ "--max-batch-size",
+ "4",
+ ],
+ },
+}
+
+print(f"Testing {len(TEST_CONFIGS)} models")
+
+
+@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
+def test_config(request) -> Dict[str, Any]:
+ """Fixture that provides model configurations for testing."""
+ test_config = TEST_CONFIGS[request.param]
+ test_config["test_name"] = request.param
+ return test_config
+
+
+@pytest.fixture(scope="module")
+def model_id(test_config):
+ yield test_config["model_id"]
+
+
+@pytest.fixture(scope="module")
+def test_name(test_config):
+ yield test_config["test_name"]
+
+
+@pytest.fixture(scope="module")
+def expected_outputs(test_config):
+ return {
+ "greedy": test_config["expected_greedy_output"],
+ # "sampling": model_config["expected_sampling_output"],
+ "batch": test_config["expected_batch_output"],
+ }
+
+
+@pytest.fixture(scope="module")
+def input(test_config):
+ return test_config["input"]
+
+
+@pytest.fixture(scope="module")
+def tgi_service(launcher, model_id, test_name):
+ with launcher(model_id, test_name) as tgi_service:
+ yield tgi_service
+
+
+@pytest.fixture(scope="module")
+async def tgi_client(tgi_service) -> AsyncClient:
+ await tgi_service.health(1000)
+ return tgi_service.client
+
+
+@pytest.mark.asyncio
+async def test_model_single_request(
+ tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
+):
+ # Greedy decoding, bounded to 32 new tokens
+ response = await tgi_client.generate(
+ input,
+ max_new_tokens=32,
+ )
+ assert response.details.generated_tokens == 32
+ assert response.generated_text == expected_outputs["greedy"]
+
+
+@pytest.mark.asyncio
+async def test_model_multiple_requests(
+ tgi_client, generate_load, expected_outputs, input
+):
+ num_requests = 4
+ responses = await generate_load(
+ tgi_client,
+ input,
+ max_new_tokens=32,
+ n=num_requests,
+ )
+
+ assert len(responses) == num_requests
+ expected = expected_outputs["batch"]
+ for r in responses:
+ assert r.details.generated_tokens == 32
+ # Compare with the expected output using the Levenshtein distance;
+ # allow at most two single-character edits (insertions, deletions or substitutions)
+ assert levenshtein_distance(r.generated_text, expected) < 3
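
As a worked illustration of the tolerance used just above (illustrative only, not part of the change): a Levenshtein distance strictly below 3 allows at most two single-character edits between the generated text and the expectation, so minor punctuation or whitespace drift passes while a different continuation fails.

from Levenshtein import distance as levenshtein_distance

expected = " A Beginner's Guide\nDeep learning is a subset of machine learning"
# One character substituted (straight vs curly apostrophe): distance 1, accepted.
assert levenshtein_distance(expected, expected.replace("'", "’")) < 3
# A genuinely different continuation: distance far above 3, rejected.
assert levenshtein_distance(expected, " A short history of neural networks") >= 3
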
diff --git a/backends/gaudi/server/poetry.lock b/backends/gaudi/server/poetry.lock
new file mode 100644
index 000000000..b9b2e1388
--- /dev/null
+++ b/backends/gaudi/server/poetry.lock
@@ -0,0 +1,3014 @@
+# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand.
+
+[[package]]
+name = "accelerate"
+version = "0.33.0"
+description = "Accelerate"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "accelerate-0.33.0-py3-none-any.whl", hash = "sha256:0a7f33d60ba09afabd028d4f0856dd19c5a734b7a596d637d9dd6e3d0eadbaf3"},
+ {file = "accelerate-0.33.0.tar.gz", hash = "sha256:11ba481ed6ea09191775df55ce464aeeba67a024bd0261a44b77b30fb439e26a"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.21.0"
+numpy = ">=1.17,<2.0.0"
+packaging = ">=20.0"
+psutil = "*"
+pyyaml = "*"
+safetensors = ">=0.3.1"
+torch = ">=1.10.0"
+
+[package.extras]
+deepspeed = ["deepspeed (<=0.14.0)"]
+dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "diffusers", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"]
+rich = ["rich"]
+sagemaker = ["sagemaker"]
+test-dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"]
+test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"]
+testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+description = "Reusable constraint types to use with typing.Annotated"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
+]
+
+[[package]]
+name = "attrs"
+version = "25.3.0"
+description = "Classes Without Boilerplate"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"},
+ {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"},
+]
+
+[package.extras]
+benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"]
+tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
+
+[[package]]
+name = "certifi"
+version = "2025.1.31"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"},
+ {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.1"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win32.whl", hash = "sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win32.whl", hash = "sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765"},
+ {file = "charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85"},
+ {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"},
+]
+
+[[package]]
+name = "click"
+version = "8.1.8"
+description = "Composable command line interface toolkit"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"},
+ {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[[package]]
+name = "cloudpickle"
+version = "3.1.1"
+description = "Pickler class to extend the standard pickle.Pickler functionality"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e"},
+ {file = "cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64"},
+]
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+groups = ["main", "dev"]
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\"", dev = "sys_platform == \"win32\""}
+
+[[package]]
+name = "deprecated"
+version = "1.2.18"
+description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
+groups = ["main"]
+files = [
+ {file = "Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec"},
+ {file = "deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d"},
+]
+
+[package.dependencies]
+wrapt = ">=1.10,<2"
+
+[package.extras]
+dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"]
+
+[[package]]
+name = "diffusers"
+version = "0.31.0"
+description = "State-of-the-art diffusion in PyTorch and JAX."
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "diffusers-0.31.0-py3-none-any.whl", hash = "sha256:cbc498ae63f4abfc7c3a07649cdcbee229ef2f9a9a1f0d19c9bbaf22f8d30c1f"},
+ {file = "diffusers-0.31.0.tar.gz", hash = "sha256:b1d01a73e45d43a0630c299173915dddd69fc50f2ae8f2ab5de4fd245eaed72f"},
+]
+
+[package.dependencies]
+filelock = "*"
+huggingface-hub = ">=0.23.2"
+importlib-metadata = "*"
+numpy = "*"
+Pillow = "*"
+regex = "!=2019.12.17"
+requests = "*"
+safetensors = ">=0.3.1"
+
+[package.extras]
+dev = ["GitPython (<3.1.19)", "Jinja2", "accelerate (>=0.31.0)", "compel (==0.1.8)", "datasets", "flax (>=0.4.1)", "hf-doc-builder (>=0.3.0)", "invisible-watermark (>=0.2.0)", "isort (>=5.5.4)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "ruff (==0.1.5)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "torch (>=1.4,<2.5.0)", "torchvision", "transformers (>=4.41.2)", "urllib3 (<=2.0.0)"]
+docs = ["hf-doc-builder (>=0.3.0)"]
+flax = ["flax (>=0.4.1)", "jax (>=0.4.1)", "jaxlib (>=0.4.1)"]
+quality = ["hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<=2.0.0)"]
+test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisible-watermark (>=0.2.0)", "k-diffusion (>=0.0.12)", "librosa", "parameterized", "pytest", "pytest-timeout", "pytest-xdist", "requests-mock (==1.10.0)", "safetensors (>=0.3.1)", "scipy", "sentencepiece (>=0.1.91,!=0.1.92)", "torchvision", "transformers (>=4.41.2)"]
+torch = ["accelerate (>=0.31.0)", "torch (>=1.4,<2.5.0)"]
+training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"]
+
+[[package]]
+name = "diskcache"
+version = "5.6.3"
+description = "Disk Cache -- Disk and file backed persistent cache."
+optional = true
+python-versions = ">=3"
+groups = ["main"]
+files = [
+ {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"},
+ {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"},
+]
+
+[[package]]
+name = "exceptiongroup"
+version = "1.2.2"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+groups = ["dev"]
+markers = "python_version < \"3.11\""
+files = [
+ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
+ {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "filelock"
+version = "3.18.0"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"},
+ {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"},
+]
+
+[package.extras]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
+typing = ["typing-extensions (>=4.12.2)"]
+
+[[package]]
+name = "fsspec"
+version = "2025.3.2"
+description = "File-system specification"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"},
+ {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"},
+]
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+dev = ["pre-commit", "ruff"]
+doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
+tqdm = ["tqdm"]
+
+[[package]]
+name = "googleapis-common-protos"
+version = "1.70.0"
+description = "Common protobufs used in Google APIs"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"},
+ {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"},
+]
+
+[package.dependencies]
+protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0)"]
+
+[[package]]
+name = "grpc-interceptor"
+version = "0.15.4"
+description = "Simplifies gRPC interceptors"
+optional = false
+python-versions = ">=3.7,<4.0"
+groups = ["main"]
+files = [
+ {file = "grpc-interceptor-0.15.4.tar.gz", hash = "sha256:1f45c0bcb58b6f332f37c637632247c9b02bc6af0fdceb7ba7ce8d2ebbfb0926"},
+ {file = "grpc_interceptor-0.15.4-py3-none-any.whl", hash = "sha256:0035f33228693ed3767ee49d937bac424318db173fef4d2d0170b3215f254d9d"},
+]
+
+[package.dependencies]
+grpcio = ">=1.49.1,<2.0.0"
+
+[package.extras]
+testing = ["protobuf (>=4.21.9)"]
+
+[[package]]
+name = "grpcio"
+version = "1.72.0rc1"
+description = "HTTP/2-based RPC framework"
+optional = false
+python-versions = ">=3.9"
+groups = ["main", "dev"]
+files = [
+ {file = "grpcio-1.72.0rc1-cp310-cp310-linux_armv7l.whl", hash = "sha256:db7db4b246a7fb21aeb70e7220be480948aa9c535eaa777ea0c840416ed8cac9"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:baf028e61662fd320c18fb50070b6e330fa24b2b3a4d113f4d57b41e0f5b5873"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:bf84cf17dfbf49ebe11b081b7a3c83b23625a80c979741e2e98b0ddb41080397"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fd6f8700d34754b32d13af234da2e413f408c8b741c8039f11beb06d53c3f6a"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f05d243b8d814dd1c6fca19e4e0c5986fc70e2c3aa29e2c7c67e877e4c03ede6"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:390a70394e2c315d7c480496db259ec16c00baeebf759c8967247269f0fee981"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b08973c62eda11343e7131d78635d50ae0c138a8f39eb817ca83cca842527d04"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ce539397a258af1dee26118c40327004d023617bc99493baaf8e7938491f7361"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-win32.whl", hash = "sha256:4f97f628095bbdf6d4c2c15c1bc18f0514f90781528bc6082bb697ccc71d4f42"},
+ {file = "grpcio-1.72.0rc1-cp310-cp310-win_amd64.whl", hash = "sha256:dbcdf7a5463b61fca1586b54f7ea3c9dfd159f535224f457ae307f52d8d4a839"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-linux_armv7l.whl", hash = "sha256:23ebb3947783f10fec3e1d0b29b94db8e72f721900d1dd9c1d6db5876da69066"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:fd96b20846907ed4cd95bf1d628f16732f450114bde897eedb323fc3bc1eddb3"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:6df1ba4a5f5793ae210699e1b1745f77a4ac17f73510fc36ee12c215f02523b4"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3398957c611f0af7cee4fdd34268b6664be8689eae0327440efb794e544908b"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ef66029da9cbe94ba3047c1b04653e1d5096ca8d036eb6e24092f0e847d2c4f"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6566e3e3458805381f8714492e8f559f082f8955ccd1c98d71f8afc0612dc841"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3c799bfa92450e95d3f1f9cc4b7d8cbefc8bd4356d3f6573d2fb5e698353192a"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a251992531f3b16be3c013ec45a9caa69ecfe9b45335652d5681659f6d117233"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-win32.whl", hash = "sha256:c9e5f2c628dedf0886b774eee17e003a043941024e68ee2ebe76be6981a7baab"},
+ {file = "grpcio-1.72.0rc1-cp311-cp311-win_amd64.whl", hash = "sha256:8b9c0a84ff584da3f5c0cb04ee3d87c0bc70d41ab5a21d3b943963a94c622892"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-linux_armv7l.whl", hash = "sha256:188ac9d8cb05c250e212ba946a65a8541419bdfd803373d6b7fb8b10fe5ff991"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:8bd956711dc21235bc78a70bf04a28b3f747c6576b9bb79362803707fec9f705"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:b032b9cbb325e28ff847b6aae1df5a090aa49b682dc80c926b24a96de43c01aa"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ca12a4388a40eb0411264af291184e2cca38176996b591ac047844abd81d40b"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7cefd52f392f4d6747b401f825901c48176737f7b03b17be0a0a638da194749"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1a24408fb051b70efa440b95f7e1acbb1c3067609934aa53a953d8d2cfc4d824"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:c7b37608d14792d3dacb9aba55b96a17a074e139c4567b0ac5c1926302add910"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:81ca42a96299ca617f3bc7b60660f15cabb98de6fce440ecd4d0640a5554345f"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-win32.whl", hash = "sha256:9ff2ef2a553d4edc8c620df3735b15a1e7dc05a60262e8c28445f2676fb09189"},
+ {file = "grpcio-1.72.0rc1-cp312-cp312-win_amd64.whl", hash = "sha256:3c9a6613662591c198d9e4e499f3336bc5c1c0e3fe3f0922cf48e74b37b3dcd1"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-linux_armv7l.whl", hash = "sha256:995e3e5c43cab6d0f1922b43b3c01a2624a4497ce91c3124e807497654301c59"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:8dfb0ff2ddd708dbecdffa37245b79aef707e789ffb0fc6a8be01608d982afcd"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:7e08eb53d6123995da63df90ce50e5b834de0a8ebfb1a3ac0890a2e246d2771c"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71cb52c0956fe7868692b490fda341a52d8187fab94e1136f5bd253c8e3560ac"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcf76ce8d4a6829f112ad88c4e6d528dbef922e01834d4a5cc3718bf599f7e84"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:8852b6234a52b6b694a5f9a5a687d59127b3e71c8e345eebd6d483abbc412217"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:d1a0fee8420d9e453dc8cba1c7c067ca2d3054487cb6616ab8dad41f15e57465"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:a13149f4fd3904093fa2dba484744dd7205f536650a533ab24dd95cca393c14c"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-win32.whl", hash = "sha256:cebe148511a1965363fc6aafd60a488fe9dc5d74dd92a59a8ecba66ddd53c573"},
+ {file = "grpcio-1.72.0rc1-cp313-cp313-win_amd64.whl", hash = "sha256:843352c352970a1df5bbf7da68d2770781f4bff2c85a4a0d20cc6eaaadf26e59"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-linux_armv7l.whl", hash = "sha256:2083c0cdff47ff7d4b093d05d703baeeef8db3b2c1f43c9f9d4288a99e444cdd"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:42df7e0f9d66f5c9b246d8e1da74605bce27b10dec20b6fc204edd6e7178da2d"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:1190c2e4f221b5bd0e6eba3e44d6758ef48eeb2216dcb9734c158e8a5d8ce6a3"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d6c8d2ea63e1cdaaa81271e5c867fcd9732050324df372ff9d3163968be68c8"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f6ee161b9d112232e5d6be437bf56383dca2334bd17e8b7a4a3f97f33722bdd"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9abbdf945e3b151603d642f2bc7a637b87af2e3480ed047689bad9eb4fa9c712"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2edab5d26319a1fed695ec658efe3846b75e0c7f3a6202b042099c9b11dc10fd"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:03b46e0041bee18a786ccef978bc29a26e4bd1b73a6ca0b21252387167843ff1"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-win32.whl", hash = "sha256:9b861cbfb63433e02b52f9971644095bec4a5fcd1e4d3f94e18cfad38f649d53"},
+ {file = "grpcio-1.72.0rc1-cp39-cp39-win_amd64.whl", hash = "sha256:2416792a567cba9f92bffc1a55ce0f2c8106956a2e32bfe8a22a8094a56b7108"},
+ {file = "grpcio-1.72.0rc1.tar.gz", hash = "sha256:221793dccd3332060f426975a041d319d6d57323d857d4afc25257ec4a5a67f3"},
+]
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.72.0rc1)"]
+
+[[package]]
+name = "grpcio-reflection"
+version = "1.71.0"
+description = "Standard Protobuf Reflection Service for gRPC"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "grpcio_reflection-1.71.0-py3-none-any.whl", hash = "sha256:8c88bdd9c92fcdd4d5df119997be05ecd0d7e10d377ec4a5072db507d2894612"},
+ {file = "grpcio_reflection-1.71.0.tar.gz", hash = "sha256:51504e977057ffabe66d1ed55557b15e969c42bb3a1f28ee45d730dd5f983bb5"},
+]
+
+[package.dependencies]
+grpcio = ">=1.71.0"
+protobuf = ">=5.26.1,<6.0dev"
+
+[[package]]
+name = "grpcio-status"
+version = "1.71.0"
+description = "Status proto mapping for gRPC"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "grpcio_status-1.71.0-py3-none-any.whl", hash = "sha256:843934ef8c09e3e858952887467f8256aac3910c55f077a359a65b2b3cde3e68"},
+ {file = "grpcio_status-1.71.0.tar.gz", hash = "sha256:11405fed67b68f406b3f3c7c5ae5104a79d2d309666d10d61b152e91d28fb968"},
+]
+
+[package.dependencies]
+googleapis-common-protos = ">=1.5.5"
+grpcio = ">=1.71.0"
+protobuf = ">=5.26.1,<6.0dev"
+
+[[package]]
+name = "grpcio-tools"
+version = "1.71.0"
+description = "Protobuf code generator for gRPC"
+optional = false
+python-versions = ">=3.9"
+groups = ["dev"]
+files = [
+ {file = "grpcio_tools-1.71.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:f4ad7f0d756546902597053d70b3af2606fbd70d7972876cd75c1e241d22ae00"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:64bdb291df61cf570b5256777ad5fe2b1db6d67bc46e55dc56a0a862722ae329"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:8dd9795e982d77a4b496f7278b943c2563d9afde2069cdee78c111a40cc4d675"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1b5860c41a36b26fec4f52998f1a451d0525a5c9a4fb06b6ea3e9211abdb925"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3059c14035e5dc03d462f261e5900b9a077fd1a36976c3865b8507474520bad4"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f360981b215b1d5aff9235b37e7e1826246e35bbac32a53e41d4e990a37b8f4c"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bfe3888c3bbe16a5aa39409bc38744a31c0c3d2daa2b0095978c56e106c85b42"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:145985c0bf12131f0a1503e65763e0f060473f7f3928ed1ff3fb0e8aad5bc8ac"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-win32.whl", hash = "sha256:82c430edd939bb863550ee0fecf067d78feff828908a1b529bbe33cc57f2419c"},
+ {file = "grpcio_tools-1.71.0-cp310-cp310-win_amd64.whl", hash = "sha256:83e90724e3f02415c628e4ead1d6ffe063820aaaa078d9a39176793df958cd5a"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:1f19b16b49afa5d21473f49c0966dd430c88d089cd52ac02404d8cef67134efb"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:459c8f5e00e390aecd5b89de67deb3ec7188a274bc6cb50e43cef35ab3a3f45d"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:edab7e6518de01196be37f96cb1e138c3819986bf5e2a6c9e1519b4d716b2f5a"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b93b9f6adc7491d4c10144c0643409db298e5e63c997106a804f6f0248dbaf4"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ae5f2efa9e644c10bf1021600bfc099dfbd8e02b184d2d25dc31fcd6c2bc59e"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:65aa082f4435571d65d5ce07fc444f23c3eff4f3e34abef599ef8c9e1f6f360f"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1331e726e08b7bdcbf2075fcf4b47dff07842b04845e6e220a08a4663e232d7f"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6693a7d3ba138b0e693b3d1f687cdd9db9e68976c3fa2b951c17a072fea8b583"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-win32.whl", hash = "sha256:6d11ed3ff7b6023b5c72a8654975324bb98c1092426ba5b481af406ff559df00"},
+ {file = "grpcio_tools-1.71.0-cp311-cp311-win_amd64.whl", hash = "sha256:072b2a5805ac97e4623b3aa8f7818275f3fb087f4aa131b0fce00471065f6eaa"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:61c0409d5bdac57a7bd0ce0ab01c1c916728fe4c8a03d77a25135ad481eb505c"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:28784f39921d061d2164a9dcda5164a69d07bf29f91f0ea50b505958292312c9"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:192808cf553cedca73f0479cc61d5684ad61f24db7a5f3c4dfe1500342425866"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:989ee9da61098230d3d4c8f8f8e27c2de796f1ff21b1c90110e636d9acd9432b"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:541a756276c8a55dec991f6c0106ae20c8c8f5ce8d0bdbfcb01e2338d1a8192b"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:870c0097700d13c403e5517cb7750ab5b4a791ce3e71791c411a38c5468b64bd"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:abd57f615e88bf93c3c6fd31f923106e3beb12f8cd2df95b0d256fa07a7a0a57"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:753270e2d06d37e6d7af8967d1d059ec635ad215882041a36294f4e2fd502b2e"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-win32.whl", hash = "sha256:0e647794bd7138b8c215e86277a9711a95cf6a03ff6f9e555d54fdf7378b9f9d"},
+ {file = "grpcio_tools-1.71.0-cp312-cp312-win_amd64.whl", hash = "sha256:48debc879570972d28bfe98e4970eff25bb26da3f383e0e49829b2d2cd35ad87"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:9a78d07d6c301a25ef5ede962920a522556a1dfee1ccc05795994ceb867f766c"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:580ac88141c9815557e63c9c04f5b1cdb19b4db8d0cb792b573354bde1ee8b12"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f7c678e68ece0ae908ecae1c4314a0c2c7f83e26e281738b9609860cc2c82d96"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56ecd6cc89b5e5eed1de5eb9cafce86c9c9043ee3840888cc464d16200290b53"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e52a041afc20ab2431d756b6295d727bd7adee813b21b06a3483f4a7a15ea15f"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2a1712f12102b60c8d92779b89d0504e0d6f3a59f2b933e5622b8583f5c02992"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:41878cb7a75477e62fdd45e7e9155b3af1b7a5332844021e2511deaf99ac9e6c"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:682e958b476049ccc14c71bedf3f979bced01f6e0c04852efc5887841a32ad6b"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-win32.whl", hash = "sha256:0ccfb837152b7b858b9f26bb110b3ae8c46675d56130f6c2f03605c4f129be13"},
+ {file = "grpcio_tools-1.71.0-cp313-cp313-win_amd64.whl", hash = "sha256:ffff9bc5eacb34dd26b487194f7d44a3e64e752fc2cf049d798021bf25053b87"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:834959b6eceb85de5217a411aba1643b5f782798680c122202d6a06177226644"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:e3ae9556e2a1cd70e7d7b0e0459c35af71d51a7dae4cf36075068011a69f13ec"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:77fe6db1334e0ce318b2cb4e70afa94e0c173ed1a533d37aea69ad9f61ae8ea9"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57e3e2544c306b60ef2d76570bac4e977be1ad548641c9eec130c3bc47e80141"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af39e245fa56f7f5c2fe86b7d6c1b78f395c07e54d5613cbdbb3c24769a92b6e"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f987d0053351217954543b174b0bddbf51d45b3cfcf8d6de97b0a43d264d753"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8e6cdbba4dae7b37b0d25d074614be9936fb720144420f03d9f142a80be69ba2"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3adc8b229e60c77bab5a5d62b415667133bd5ced7d59b5f71d6317c9143631e"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-win32.whl", hash = "sha256:f68334d28a267fabec6e70cb5986e9999cfbfd14db654094ddf9aedd804a293a"},
+ {file = "grpcio_tools-1.71.0-cp39-cp39-win_amd64.whl", hash = "sha256:1291a6136c07a86c3bb09f6c33f5cf227cc14956edd1b85cb572327a36e0aef8"},
+ {file = "grpcio_tools-1.71.0.tar.gz", hash = "sha256:38dba8e0d5e0fb23a034e09644fdc6ed862be2371887eee54901999e8f6792a8"},
+]
+
+[package.dependencies]
+grpcio = ">=1.71.0"
+protobuf = ">=5.26.1,<6.0dev"
+setuptools = "*"
+
+[[package]]
+name = "hf-transfer"
+version = "0.1.9"
+description = "Speed up file transfers with the Hugging Face Hub."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "hf_transfer-0.1.9-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:6e94e8822da79573c9b6ae4d6b2f847c59a7a06c5327d7db20751b68538dc4f6"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ebc4ab9023414880c8b1d3c38174d1c9989eb5022d37e814fa91a3060123eb0"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8674026f21ed369aa2a0a4b46000aca850fc44cd2b54af33a172ce5325b4fc82"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a736dfbb2c84f5a2c975478ad200c0c8bfcb58a25a35db402678fb87ce17fa4"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:504b8427fd785dd8546d53b9fafe6e436bd7a3adf76b9dce556507650a7b4567"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c7fc1b85f4d0f76e452765d7648c9f4bfd0aedb9ced2ae1ebfece2d8cfaf8e2"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d991376f0eac70a60f0cbc95602aa708a6f7c8617f28b4945c1431d67b8e3c8"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e6ac4eddcd99575ed3735ed911ddf9d1697e2bd13aa3f0ad7e3904dd4863842e"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:57fd9880da1ee0f47250f735f791fab788f0aa1ee36afc49f761349869c8b4d9"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:5d561f0520f493c66b016d99ceabe69c23289aa90be38dd802d2aef279f15751"},
+ {file = "hf_transfer-0.1.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a5b366d34cd449fe9b20ef25941e6eef0460a2f74e7389f02e673e1f88ebd538"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e66acf91df4a8b72f60223059df3003062a5ae111757187ed1a06750a30e911b"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8669dbcc7a3e2e8d61d42cd24da9c50d57770bd74b445c65123291ca842a7e7a"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fd0167c4407a3bc4cdd0307e65ada2294ec04f1813d8a69a5243e379b22e9d8"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ee8b10afedcb75f71091bcc197c526a6ebf5c58bbbadb34fdeee6160f55f619f"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5828057e313de59300dd1abb489444bc452efe3f479d3c55b31a8f680936ba42"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc6bd19e1cc177c66bdef15ef8636ad3bde79d5a4f608c158021153b4573509d"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdca9bfb89e6f8f281890cc61a8aff2d3cecaff7e1a4d275574d96ca70098557"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:89a23f58b7b7effbc047b8ca286f131b17728c99a9f972723323003ffd1bb916"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:dc7fff1345980d6c0ebb92c811d24afa4b98b3e07ed070c8e38cc91fd80478c5"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1a6bd16c667ebe89a069ca163060127a794fa3a3525292c900b8c8cc47985b0d"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d2fde99d502093ade3ab1b53f80da18480e9902aa960dab7f74fb1b9e5bc5746"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-win32.whl", hash = "sha256:435cc3cdc8524ce57b074032b8fd76eed70a4224d2091232fa6a8cef8fd6803e"},
+ {file = "hf_transfer-0.1.9-cp38-abi3-win_amd64.whl", hash = "sha256:16f208fc678911c37e11aa7b586bc66a37d02e636208f18b6bc53d29b5df40ad"},
+ {file = "hf_transfer-0.1.9.tar.gz", hash = "sha256:035572865dab29d17e783fbf1e84cf1cb24f3fcf8f1b17db1cfc7fdf139f02bf"},
+]
+
+[[package]]
+name = "huggingface-hub"
+version = "0.30.2"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
+ {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = ">=2023.5.0"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+hf-transfer = ["hf-transfer (>=0.1.4)"]
+hf-xet = ["hf-xet (>=0.1.4)"]
+inference = ["aiohttp"]
+quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors[torch]", "torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
+
+[[package]]
+name = "idna"
+version = "3.10"
+description = "Internationalized Domain Names in Applications (IDNA)"
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
+ {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
+]
+
+[package.extras]
+all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
+
+[[package]]
+name = "importlib-metadata"
+version = "8.6.1"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e"},
+ {file = "importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580"},
+]
+
+[package.dependencies]
+zipp = ">=3.20"
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+enabler = ["pytest-enabler (>=2.2)"]
+perf = ["ipython"]
+test = ["flufl.flake8", "importlib_resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"]
+type = ["pytest-mypy"]
+
+[[package]]
+name = "iniconfig"
+version = "2.1.0"
+description = "brain-dead simple config-ini parsing"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+files = [
+ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
+ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
+]
+
+[[package]]
+name = "interegular"
+version = "0.3.3"
+description = "a regex intersection checker"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"},
+ {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"},
+]
+
+[[package]]
+name = "jinja2"
+version = "3.1.6"
+description = "A very fast and expressive template engine."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"},
+ {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.0"
+
+[package.extras]
+i18n = ["Babel (>=2.7)"]
+
+[[package]]
+name = "joblib"
+version = "1.4.2"
+description = "Lightweight pipelining with Python functions"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
+ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
+]
+
+[[package]]
+name = "jsonschema"
+version = "4.23.0"
+description = "An implementation of JSON Schema validation for Python"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"},
+ {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+jsonschema-specifications = ">=2023.03.6"
+referencing = ">=0.28.4"
+rpds-py = ">=0.7.1"
+
+[package.extras]
+format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
+format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"]
+
+[[package]]
+name = "jsonschema-specifications"
+version = "2024.10.1"
+description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"},
+ {file = "jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272"},
+]
+
+[package.dependencies]
+referencing = ">=0.31.0"
+
+[[package]]
+name = "lark"
+version = "1.2.2"
+description = "a modern parsing library"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c"},
+ {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"},
+]
+
+[package.extras]
+atomic-cache = ["atomicwrites"]
+interegular = ["interegular (>=0.3.1,<0.4.0)"]
+nearley = ["js2py"]
+regex = ["regex"]
+
+[[package]]
+name = "llvmlite"
+version = "0.43.0"
+description = "lightweight wrapper around basic LLVM functionality"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"},
+ {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"},
+ {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"},
+ {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"},
+ {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"},
+ {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"},
+ {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"},
+ {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"},
+ {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"},
+ {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"},
+ {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"},
+ {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"},
+ {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"},
+ {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"},
+ {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"},
+ {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"},
+ {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"},
+ {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"},
+ {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"},
+ {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"},
+ {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"},
+]
+
+[[package]]
+name = "loguru"
+version = "0.7.3"
+description = "Python logging made (stupidly) simple"
+optional = false
+python-versions = "<4.0,>=3.5"
+groups = ["main"]
+files = [
+ {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"},
+ {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"},
+]
+
+[package.dependencies]
+colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
+win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
+
+[package.extras]
+dev = ["Sphinx (==8.1.3)", "build (==1.2.2)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.5.0)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.13.0)", "mypy (==v1.4.1)", "myst-parser (==4.0.0)", "pre-commit (==4.0.1)", "pytest (==6.1.2)", "pytest (==8.3.2)", "pytest-cov (==2.12.1)", "pytest-cov (==5.0.0)", "pytest-cov (==6.0.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.1.0)", "sphinx-rtd-theme (==3.0.2)", "tox (==3.27.1)", "tox (==4.23.2)", "twine (==6.0.1)"]
+
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+ {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[[package]]
+name = "markupsafe"
+version = "3.0.2"
+description = "Safely add untrusted strings to HTML/XML markup."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50"},
+ {file = "MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d"},
+ {file = "MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30"},
+ {file = "MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6"},
+ {file = "MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-win32.whl", hash = "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f"},
+ {file = "MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a"},
+ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"},
+]
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+description = "Python library for arbitrary-precision floating-point arithmetic"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[package.extras]
+develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"]
+docs = ["sphinx"]
+gmpy = ["gmpy2 (>=2.1.0a4)"]
+tests = ["pytest (>=4.6)"]
+
+[[package]]
+name = "nest-asyncio"
+version = "1.6.0"
+description = "Patch asyncio to allow nested event loops"
+optional = true
+python-versions = ">=3.5"
+groups = ["main"]
+files = [
+ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"},
+ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"},
+]
+
+[[package]]
+name = "networkx"
+version = "3.2.1"
+description = "Python package for creating and manipulating graphs and networks"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"},
+ {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"},
+]
+
+[package.extras]
+default = ["matplotlib (>=3.5)", "numpy (>=1.22)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"]
+developer = ["changelist (==0.4)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"]
+doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"]
+extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
+test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
+
+[[package]]
+name = "numba"
+version = "0.60.0"
+description = "compiling Python code using LLVM"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"},
+ {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"},
+ {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"},
+ {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"},
+ {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"},
+ {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"},
+ {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"},
+ {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"},
+ {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"},
+ {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"},
+ {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"},
+ {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"},
+ {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"},
+ {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"},
+ {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"},
+ {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"},
+ {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"},
+ {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"},
+ {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"},
+ {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"},
+ {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"},
+]
+
+[package.dependencies]
+llvmlite = "==0.43.*"
+numpy = ">=1.22,<2.1"
+
+[[package]]
+name = "numpy"
+version = "1.26.4"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
+ {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
+ {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
+ {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
+ {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
+ {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
+ {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
+ {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
+ {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
+ {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
+ {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
+ {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
+ {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
+ {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
+ {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
+ {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
+ {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
+ {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
+ {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
+ {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
+ {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
+ {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
+ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
+]
+
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.4.5.8"
+description = "CUBLAS native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
+ {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
+ {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"},
+]
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.4.127"
+description = "CUDA profiling tools runtime libs."
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
+ {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
+ {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
+]
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.4.127"
+description = "NVRTC native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"},
+]
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.4.127"
+description = "CUDA Runtime native Libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
+ {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
+ {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"},
+]
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "9.1.0.70"
+description = "cuDNN runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"},
+ {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"},
+]
+
+[package.dependencies]
+nvidia-cublas-cu12 = "*"
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.2.1.3"
+description = "CUFFT native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"},
+ {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"},
+ {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"},
+]
+
+[package.dependencies]
+nvidia-nvjitlink-cu12 = "*"
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.5.147"
+description = "CURAND native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"},
+ {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"},
+ {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"},
+]
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.6.1.9"
+description = "CUDA solver native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"},
+ {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"},
+ {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"},
+]
+
+[package.dependencies]
+nvidia-cublas-cu12 = "*"
+nvidia-cusparse-cu12 = "*"
+nvidia-nvjitlink-cu12 = "*"
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.3.1.170"
+description = "CUSPARSE native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"},
+ {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"},
+ {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"},
+]
+
+[package.dependencies]
+nvidia-nvjitlink-cu12 = "*"
+
+[[package]]
+name = "nvidia-cusparselt-cu12"
+version = "0.6.2"
+description = "NVIDIA cuSPARSELt"
+optional = false
+python-versions = "*"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8"},
+ {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9"},
+ {file = "nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70"},
+]
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.21.5"
+description = "NVIDIA Collective Communication Library (NCCL) Runtime"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
+]
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.4.127"
+description = "Nvidia JIT LTO Library"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
+ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
+]
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.4.127"
+description = "NVIDIA Tools Extension"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"},
+ {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"},
+ {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
+]
+
+[[package]]
+name = "opentelemetry-api"
+version = "1.32.0"
+description = "OpenTelemetry Python API"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_api-1.32.0-py3-none-any.whl", hash = "sha256:15df743c765078611f376037b0d9111ec5c1febf2ec9440cdd919370faa1ce55"},
+ {file = "opentelemetry_api-1.32.0.tar.gz", hash = "sha256:2623280c916f9b19cad0aa4280cb171265f19fd2909b0d47e4f06f7c83b02cb5"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+importlib-metadata = ">=6.0,<8.7.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp"
+version = "1.32.0"
+description = "OpenTelemetry Collector Exporters"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp-1.32.0-py3-none-any.whl", hash = "sha256:8b563bee30f05415fb51e075eb6461cdaa7bcef1cc79917cfd79caf12e5bb548"},
+ {file = "opentelemetry_exporter_otlp-1.32.0.tar.gz", hash = "sha256:4c66681f8acd95dce44966842182e3690e77256e5791ceb34b76ea1c34b20463"},
+]
+
+[package.dependencies]
+opentelemetry-exporter-otlp-proto-grpc = "1.32.0"
+opentelemetry-exporter-otlp-proto-http = "1.32.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-common"
+version = "1.32.0"
+description = "OpenTelemetry Protobuf encoding"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_common-1.32.0-py3-none-any.whl", hash = "sha256:277a63a18768b3b460d082a489f6f80d4ae2c1e6b185bb701c6bd4e91405e4bd"},
+ {file = "opentelemetry_exporter_otlp_proto_common-1.32.0.tar.gz", hash = "sha256:2bca672f2a279c4f517115e635c0cc1269d07b2982a36681c521f7e56179a222"},
+]
+
+[package.dependencies]
+opentelemetry-proto = "1.32.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-grpc"
+version = "1.32.0"
+description = "OpenTelemetry Collector Protobuf over gRPC Exporter"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_grpc-1.32.0-py3-none-any.whl", hash = "sha256:85b7c42bebe48ef55866793a3123ebf357dcaf629d961b27067025fd60104dbe"},
+ {file = "opentelemetry_exporter_otlp_proto_grpc-1.32.0.tar.gz", hash = "sha256:c069c5d5f429a46fb1001f38191730939f593789c847648e4cea26dc8b6018a8"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+googleapis-common-protos = ">=1.52,<2.0"
+grpcio = {version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""}
+opentelemetry-api = ">=1.15,<2.0"
+opentelemetry-exporter-otlp-proto-common = "1.32.0"
+opentelemetry-proto = "1.32.0"
+opentelemetry-sdk = ">=1.32.0,<1.33.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-http"
+version = "1.32.0"
+description = "OpenTelemetry Collector Protobuf over HTTP Exporter"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_exporter_otlp_proto_http-1.32.0-py3-none-any.whl", hash = "sha256:e2ffecd6d2220eaf1291a46339f109bc0a57ee7c4c6abb8174df418bf00ce01f"},
+ {file = "opentelemetry_exporter_otlp_proto_http-1.32.0.tar.gz", hash = "sha256:a5dfd94603da86e313e4f4fb8d181fd3b64a7c2a9c7b408c3653d2b1bc68d14f"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+googleapis-common-protos = ">=1.52,<2.0"
+opentelemetry-api = ">=1.15,<2.0"
+opentelemetry-exporter-otlp-proto-common = "1.32.0"
+opentelemetry-proto = "1.32.0"
+opentelemetry-sdk = ">=1.32.0,<1.33.0"
+requests = ">=2.7,<3.0"
+
+[[package]]
+name = "opentelemetry-instrumentation"
+version = "0.53b0"
+description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_instrumentation-0.53b0-py3-none-any.whl", hash = "sha256:70600778fd567c9c5fbfca181378ae179c0dec3ff613171707d3d77c360ff105"},
+ {file = "opentelemetry_instrumentation-0.53b0.tar.gz", hash = "sha256:f2c21d71a3cdf28c656e3d90d247ee7558fb9b0239b3d9e9190266499dbed9d2"},
+]
+
+[package.dependencies]
+opentelemetry-api = ">=1.4,<2.0"
+opentelemetry-semantic-conventions = "0.53b0"
+packaging = ">=18.0"
+wrapt = ">=1.0.0,<2.0.0"
+
+[[package]]
+name = "opentelemetry-instrumentation-grpc"
+version = "0.53b0"
+description = "OpenTelemetry gRPC instrumentation"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_instrumentation_grpc-0.53b0-py3-none-any.whl", hash = "sha256:bd44f113c58fd66614b07bd9b8115ec311389ec58ef7e48a06581e302971c3f4"},
+ {file = "opentelemetry_instrumentation_grpc-0.53b0.tar.gz", hash = "sha256:a95b752e0782e7b503379de1c64a5afa2c7c1cd8196fa5f2b5c090d01c15e517"},
+]
+
+[package.dependencies]
+opentelemetry-api = ">=1.12,<2.0"
+opentelemetry-instrumentation = "0.53b0"
+opentelemetry-semantic-conventions = "0.53b0"
+wrapt = ">=1.0.0,<2.0.0"
+
+[package.extras]
+instruments = ["grpcio (>=1.42.0)"]
+
+[[package]]
+name = "opentelemetry-proto"
+version = "1.32.0"
+description = "OpenTelemetry Python Proto"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_proto-1.32.0-py3-none-any.whl", hash = "sha256:f699269dc037e18fba05442580a8682c9fbd0f4c7f5addfed82c44be0c53c5ff"},
+ {file = "opentelemetry_proto-1.32.0.tar.gz", hash = "sha256:f8b70ae52f4ef8a4e4c0760e87c9071e07ece2618c080d4839bef44c0156cd44"},
+]
+
+[package.dependencies]
+protobuf = ">=5.0,<6.0"
+
+[[package]]
+name = "opentelemetry-sdk"
+version = "1.32.0"
+description = "OpenTelemetry Python SDK"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_sdk-1.32.0-py3-none-any.whl", hash = "sha256:ed252d035c22a15536c1f603ca089298daab60850fc2f5ddfa95d95cc1c043ea"},
+ {file = "opentelemetry_sdk-1.32.0.tar.gz", hash = "sha256:5ff07fb371d1ab1189fa7047702e2e888b5403c5efcbb18083cae0d5aa5f58d2"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.32.0"
+opentelemetry-semantic-conventions = "0.53b0"
+typing-extensions = ">=3.7.4"
+
+[[package]]
+name = "opentelemetry-semantic-conventions"
+version = "0.53b0"
+description = "OpenTelemetry Semantic Conventions"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "opentelemetry_semantic_conventions-0.53b0-py3-none-any.whl", hash = "sha256:561da89f766ab51615c0e72b12329e0a1bc16945dbd62c8646ffc74e36a1edff"},
+ {file = "opentelemetry_semantic_conventions-0.53b0.tar.gz", hash = "sha256:05b7908e1da62d72f9bf717ed25c72f566fe005a2dd260c61b11e025f2552cf6"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+opentelemetry-api = "1.32.0"
+
+[[package]]
+name = "optimum"
+version = "1.24.0"
+description = "Optimum Library is an extension of the Hugging Face Transformers library, providing a framework to integrate third-party libraries from Hardware Partners and interface with their specific functionality."
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "optimum-1.24.0-py3-none-any.whl", hash = "sha256:196776949183cd3a56a15097a02be41e6f37aa92d824bd053de89c39ee6b0087"},
+ {file = "optimum-1.24.0.tar.gz", hash = "sha256:b502a2afbf78bb73370ebb1eff07b93108a1b386116e87eb17e882d210150551"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.8.0"
+numpy = "*"
+packaging = "*"
+torch = ">=1.11"
+transformers = ">=4.29"
+
+[package.extras]
+amd = ["optimum-amd"]
+benchmark = ["evaluate (>=0.2.0)", "optuna", "scikit-learn", "seqeval", "torchvision", "tqdm"]
+dev = ["Pillow", "accelerate", "black (>=23.1,<24.0)", "einops", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "ruff (==0.1.5)", "sacremoses", "scikit-learn", "sentencepiece", "timm", "torchaudio", "torchvision"]
+diffusers = ["diffusers"]
+doc-build = ["accelerate"]
+exporters = ["onnx", "onnxruntime", "timm", "transformers (>=4.36,<4.49.0)"]
+exporters-gpu = ["onnx", "onnxruntime-gpu", "timm", "transformers (>=4.36,<4.49.0)"]
+exporters-tf = ["datasets (<=2.16)", "h5py", "numpy (<1.24.0)", "onnx", "onnxruntime", "tensorflow (>=2.4,<=2.12.1)", "tf2onnx", "timm", "transformers (>=4.36,<4.38)"]
+furiosa = ["optimum-furiosa"]
+graphcore = ["optimum-graphcore"]
+habana = ["optimum-habana", "transformers (>=4.45.0,<4.46.0)"]
+intel = ["optimum-intel (>=1.18.0)"]
+ipex = ["optimum-intel[ipex] (>=1.18.0)"]
+neural-compressor = ["optimum-intel[neural-compressor] (>=1.18.0)"]
+neuron = ["optimum-neuron[neuron] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"]
+neuronx = ["optimum-neuron[neuronx] (>=0.0.20)", "transformers (>=4.36.2,<4.42.0)"]
+nncf = ["optimum-intel[nncf] (>=1.18.0)"]
+onnxruntime = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime (>=1.11.0)", "protobuf (>=3.20.1)", "transformers (>=4.36,<4.49.0)"]
+onnxruntime-gpu = ["datasets (>=1.2.1)", "evaluate", "onnx", "onnxruntime-gpu (>=1.11.0)", "protobuf (>=3.20.1)", "transformers (>=4.36,<4.49.0)"]
+onnxruntime-training = ["accelerate", "datasets (>=1.2.1)", "evaluate", "onnxruntime-training (>=1.11.0)", "protobuf (>=3.20.1)", "torch-ort", "transformers (>=4.36,<4.49.0)"]
+openvino = ["optimum-intel[openvino] (>=1.18.0)"]
+quality = ["black (>=23.1,<24.0)", "ruff (==0.1.5)"]
+quanto = ["optimum-quanto (>=0.2.4)"]
+tests = ["Pillow", "accelerate", "einops", "parameterized", "pytest (<=8.0.0)", "pytest-xdist", "requests", "rjieba", "sacremoses", "scikit-learn", "sentencepiece", "timm", "torchaudio", "torchvision"]
+
+[[package]]
+name = "optimum-habana"
+version = "1.17.0"
+description = "Optimum Habana is the interface between the Hugging Face Transformers and Diffusers libraries and Habana's Gaudi processor (HPU). It provides a set of tools enabling easy model loading, training and inference on single- and multi-HPU settings for different downstream tasks."
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "optimum_habana-1.17.0-py3-none-any.whl", hash = "sha256:4f1008c7e84248b62778c8d5f79443237026a5a281b50ee67c7db211c5ca7d2a"},
+ {file = "optimum_habana-1.17.0.tar.gz", hash = "sha256:634adaa775c5c1694a164bdec46133e9712676237e996aadf3392c820573ce92"},
+]
+
+[package.dependencies]
+accelerate = ">=0.33.0,<0.34.0"
+diffusers = ">=0.31.0,<0.32.0"
+huggingface_hub = ">=0.24.7"
+optimum = "*"
+sentence-transformers = "3.3.1"
+torch = "*"
+transformers = ">=4.49.0,<4.50.0"
+
+[package.extras]
+quality = ["hf_doc_builder", "ruff"]
+tests = ["GitPython", "datasets", "optuna", "parameterized", "peft", "psutil", "pytest (<8.0.0)", "safetensors", "scipy", "sentencepiece", "timm", "timm", "torchsde"]
+
+[[package]]
+name = "outlines"
+version = "0.0.36"
+description = "Probabilistic Generative Model Programming"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "outlines-0.0.36-py3-none-any.whl", hash = "sha256:afa02ca5c449c47731fa06af66d13c2f5ee8b30f8b82b4db90e08215d6f111d1"},
+ {file = "outlines-0.0.36.tar.gz", hash = "sha256:3cffb43143548cd78c6061990feb461cffd5479999391b8390471ea839c2d46e"},
+]
+
+[package.dependencies]
+cloudpickle = "*"
+diskcache = "*"
+interegular = "*"
+jinja2 = "*"
+joblib = "*"
+jsonschema = "*"
+lark = "*"
+nest-asyncio = "*"
+numba = "*"
+numpy = "*"
+pydantic = ">=2.0"
+referencing = "*"
+requests = "*"
+scipy = "*"
+torch = ">=2.1.0"
+transformers = "*"
+
+[package.extras]
+serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"]
+test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "openai (>=1.0.0)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"]
+
+[[package]]
+name = "packaging"
+version = "24.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "dev"]
+files = [
+ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
+ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
+]
+
+[[package]]
+name = "peft"
+version = "0.15.1"
+description = "Parameter-Efficient Fine-Tuning (PEFT)"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "peft-0.15.1-py3-none-any.whl", hash = "sha256:5fb3960beb518f00668f2cdc53424a5cc495c78281697821ce24609c90ca0a10"},
+ {file = "peft-0.15.1.tar.gz", hash = "sha256:e4c65af70683a9ef3baf1ab450710f1eb7181f369ef6172ca8bf15bf4ae6ff71"},
+]
+
+[package.dependencies]
+accelerate = ">=0.21.0"
+huggingface_hub = ">=0.25.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+psutil = "*"
+pyyaml = "*"
+safetensors = "*"
+torch = ">=1.13.0"
+tqdm = "*"
+transformers = "*"
+
+[package.extras]
+dev = ["black", "black", "hf-doc-builder", "hf-doc-builder", "ruff (>=0.9.2,<0.10.0)"]
+docs-specific = ["black", "hf-doc-builder"]
+quality = ["black", "hf-doc-builder", "ruff (>=0.9.2,<0.10.0)"]
+test = ["black", "black", "datasets", "diffusers", "hf-doc-builder", "hf-doc-builder", "parameterized", "protobuf", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.9.2,<0.10.0)", "scipy", "sentencepiece"]
+
+[[package]]
+name = "pillow"
+version = "11.2.1"
+description = "Python Imaging Library (Fork)"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pillow-11.2.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:d57a75d53922fc20c165016a20d9c44f73305e67c351bbc60d1adaf662e74047"},
+ {file = "pillow-11.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:127bf6ac4a5b58b3d32fc8289656f77f80567d65660bc46f72c0d77e6600cc95"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4ba4be812c7a40280629e55ae0b14a0aafa150dd6451297562e1764808bbe61"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8bd62331e5032bc396a93609982a9ab6b411c05078a52f5fe3cc59234a3abd1"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:562d11134c97a62fe3af29581f083033179f7ff435f78392565a1ad2d1c2c45c"},
+ {file = "pillow-11.2.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c97209e85b5be259994eb5b69ff50c5d20cca0f458ef9abd835e262d9d88b39d"},
+ {file = "pillow-11.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0c3e6d0f59171dfa2e25d7116217543310908dfa2770aa64b8f87605f8cacc97"},
+ {file = "pillow-11.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc1c3bc53befb6096b84165956e886b1729634a799e9d6329a0c512ab651e579"},
+ {file = "pillow-11.2.1-cp310-cp310-win32.whl", hash = "sha256:312c77b7f07ab2139924d2639860e084ec2a13e72af54d4f08ac843a5fc9c79d"},
+ {file = "pillow-11.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9bc7ae48b8057a611e5fe9f853baa88093b9a76303937449397899385da06fad"},
+ {file = "pillow-11.2.1-cp310-cp310-win_arm64.whl", hash = "sha256:2728567e249cdd939f6cc3d1f049595c66e4187f3c34078cbc0a7d21c47482d2"},
+ {file = "pillow-11.2.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:35ca289f712ccfc699508c4658a1d14652e8033e9b69839edf83cbdd0ba39e70"},
+ {file = "pillow-11.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0409af9f829f87a2dfb7e259f78f317a5351f2045158be321fd135973fff7bf"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4e5c5edee874dce4f653dbe59db7c73a600119fbea8d31f53423586ee2aafd7"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b93a07e76d13bff9444f1a029e0af2964e654bfc2e2c2d46bfd080df5ad5f3d8"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:e6def7eed9e7fa90fde255afaf08060dc4b343bbe524a8f69bdd2a2f0018f600"},
+ {file = "pillow-11.2.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:8f4f3724c068be008c08257207210c138d5f3731af6c155a81c2b09a9eb3a788"},
+ {file = "pillow-11.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a0a6709b47019dff32e678bc12c63008311b82b9327613f534e496dacaefb71e"},
+ {file = "pillow-11.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f6b0c664ccb879109ee3ca702a9272d877f4fcd21e5eb63c26422fd6e415365e"},
+ {file = "pillow-11.2.1-cp311-cp311-win32.whl", hash = "sha256:cc5d875d56e49f112b6def6813c4e3d3036d269c008bf8aef72cd08d20ca6df6"},
+ {file = "pillow-11.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:0f5c7eda47bf8e3c8a283762cab94e496ba977a420868cb819159980b6709193"},
+ {file = "pillow-11.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:4d375eb838755f2528ac8cbc926c3e31cc49ca4ad0cf79cff48b20e30634a4a7"},
+ {file = "pillow-11.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:78afba22027b4accef10dbd5eed84425930ba41b3ea0a86fa8d20baaf19d807f"},
+ {file = "pillow-11.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:78092232a4ab376a35d68c4e6d5e00dfd73454bd12b230420025fbe178ee3b0b"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25a5f306095c6780c52e6bbb6109624b95c5b18e40aab1c3041da3e9e0cd3e2d"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c7b29dbd4281923a2bfe562acb734cee96bbb129e96e6972d315ed9f232bef4"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e645b020f3209a0181a418bffe7b4a93171eef6c4ef6cc20980b30bebf17b7d"},
+ {file = "pillow-11.2.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2dbea1012ccb784a65349f57bbc93730b96e85b42e9bf7b01ef40443db720b4"},
+ {file = "pillow-11.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:da3104c57bbd72948d75f6a9389e6727d2ab6333c3617f0a89d72d4940aa0443"},
+ {file = "pillow-11.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:598174aef4589af795f66f9caab87ba4ff860ce08cd5bb447c6fc553ffee603c"},
+ {file = "pillow-11.2.1-cp312-cp312-win32.whl", hash = "sha256:1d535df14716e7f8776b9e7fee118576d65572b4aad3ed639be9e4fa88a1cad3"},
+ {file = "pillow-11.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:14e33b28bf17c7a38eede290f77db7c664e4eb01f7869e37fa98a5aa95978941"},
+ {file = "pillow-11.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:21e1470ac9e5739ff880c211fc3af01e3ae505859392bf65458c224d0bf283eb"},
+ {file = "pillow-11.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fdec757fea0b793056419bca3e9932eb2b0ceec90ef4813ea4c1e072c389eb28"},
+ {file = "pillow-11.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b0e130705d568e2f43a17bcbe74d90958e8a16263868a12c3e0d9c8162690830"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bdb5e09068332578214cadd9c05e3d64d99e0e87591be22a324bdbc18925be0"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d189ba1bebfbc0c0e529159631ec72bb9e9bc041f01ec6d3233d6d82eb823bc1"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:191955c55d8a712fab8934a42bfefbf99dd0b5875078240943f913bb66d46d9f"},
+ {file = "pillow-11.2.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:ad275964d52e2243430472fc5d2c2334b4fc3ff9c16cb0a19254e25efa03a155"},
+ {file = "pillow-11.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:750f96efe0597382660d8b53e90dd1dd44568a8edb51cb7f9d5d918b80d4de14"},
+ {file = "pillow-11.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fe15238d3798788d00716637b3d4e7bb6bde18b26e5d08335a96e88564a36b6b"},
+ {file = "pillow-11.2.1-cp313-cp313-win32.whl", hash = "sha256:3fe735ced9a607fee4f481423a9c36701a39719252a9bb251679635f99d0f7d2"},
+ {file = "pillow-11.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:74ee3d7ecb3f3c05459ba95eed5efa28d6092d751ce9bf20e3e253a4e497e691"},
+ {file = "pillow-11.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:5119225c622403afb4b44bad4c1ca6c1f98eed79db8d3bc6e4e160fc6339d66c"},
+ {file = "pillow-11.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:8ce2e8411c7aaef53e6bb29fe98f28cd4fbd9a1d9be2eeea434331aac0536b22"},
+ {file = "pillow-11.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9ee66787e095127116d91dea2143db65c7bb1e232f617aa5957c0d9d2a3f23a7"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9622e3b6c1d8b551b6e6f21873bdcc55762b4b2126633014cea1803368a9aa16"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63b5dff3a68f371ea06025a1a6966c9a1e1ee452fc8020c2cd0ea41b83e9037b"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:31df6e2d3d8fc99f993fd253e97fae451a8db2e7207acf97859732273e108406"},
+ {file = "pillow-11.2.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:062b7a42d672c45a70fa1f8b43d1d38ff76b63421cbbe7f88146b39e8a558d91"},
+ {file = "pillow-11.2.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4eb92eca2711ef8be42fd3f67533765d9fd043b8c80db204f16c8ea62ee1a751"},
+ {file = "pillow-11.2.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f91ebf30830a48c825590aede79376cb40f110b387c17ee9bd59932c961044f9"},
+ {file = "pillow-11.2.1-cp313-cp313t-win32.whl", hash = "sha256:e0b55f27f584ed623221cfe995c912c61606be8513bfa0e07d2c674b4516d9dd"},
+ {file = "pillow-11.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:36d6b82164c39ce5482f649b437382c0fb2395eabc1e2b1702a6deb8ad647d6e"},
+ {file = "pillow-11.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:225c832a13326e34f212d2072982bb1adb210e0cc0b153e688743018c94a2681"},
+ {file = "pillow-11.2.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:7491cf8a79b8eb867d419648fff2f83cb0b3891c8b36da92cc7f1931d46108c8"},
+ {file = "pillow-11.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b02d8f9cb83c52578a0b4beadba92e37d83a4ef11570a8688bbf43f4ca50909"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:014ca0050c85003620526b0ac1ac53f56fc93af128f7546623cc8e31875ab928"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3692b68c87096ac6308296d96354eddd25f98740c9d2ab54e1549d6c8aea9d79"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:f781dcb0bc9929adc77bad571b8621ecb1e4cdef86e940fe2e5b5ee24fd33b35"},
+ {file = "pillow-11.2.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:2b490402c96f907a166615e9a5afacf2519e28295f157ec3a2bb9bd57de638cb"},
+ {file = "pillow-11.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dd6b20b93b3ccc9c1b597999209e4bc5cf2853f9ee66e3fc9a400a78733ffc9a"},
+ {file = "pillow-11.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4b835d89c08a6c2ee7781b8dd0a30209a8012b5f09c0a665b65b0eb3560b6f36"},
+ {file = "pillow-11.2.1-cp39-cp39-win32.whl", hash = "sha256:b10428b3416d4f9c61f94b494681280be7686bda15898a3a9e08eb66a6d92d67"},
+ {file = "pillow-11.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:6ebce70c3f486acf7591a3d73431fa504a4e18a9b97ff27f5f47b7368e4b9dd1"},
+ {file = "pillow-11.2.1-cp39-cp39-win_arm64.whl", hash = "sha256:c27476257b2fdcd7872d54cfd119b3a9ce4610fb85c8e32b70b42e3680a29a1e"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9b7b0d4fd2635f54ad82785d56bc0d94f147096493a79985d0ab57aedd563156"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:aa442755e31c64037aa7c1cb186e0b369f8416c567381852c63444dd666fb772"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0d3348c95b766f54b76116d53d4cb171b52992a1027e7ca50c81b43b9d9e363"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85d27ea4c889342f7e35f6d56e7e1cb345632ad592e8c51b693d7b7556043ce0"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bf2c33d6791c598142f00c9c4c7d47f6476731c31081331664eb26d6ab583e01"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e616e7154c37669fc1dfc14584f11e284e05d1c650e1c0f972f281c4ccc53193"},
+ {file = "pillow-11.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:39ad2e0f424394e3aebc40168845fee52df1394a4673a6ee512d840d14ab3013"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:80f1df8dbe9572b4b7abdfa17eb5d78dd620b1d55d9e25f834efdbee872d3aed"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:ea926cfbc3957090becbcbbb65ad177161a2ff2ad578b5a6ec9bb1e1cd78753c"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:738db0e0941ca0376804d4de6a782c005245264edaa253ffce24e5a15cbdc7bd"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9db98ab6565c69082ec9b0d4e40dd9f6181dab0dd236d26f7a50b8b9bfbd5076"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:036e53f4170e270ddb8797d4c590e6dd14d28e15c7da375c18978045f7e6c37b"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:14f73f7c291279bd65fda51ee87affd7c1e097709f7fdd0188957a16c264601f"},
+ {file = "pillow-11.2.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:208653868d5c9ecc2b327f9b9ef34e0e42a4cdd172c2988fd81d62d2bc9bc044"},
+ {file = "pillow-11.2.1.tar.gz", hash = "sha256:a64dd61998416367b7ef979b73d3a85853ba9bec4c2925f74e588879a58716b6"},
+]
+
+[package.extras]
+docs = ["furo", "olefile", "sphinx (>=8.2)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"]
+fpx = ["olefile"]
+mic = ["olefile"]
+test-arrow = ["pyarrow"]
+tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout", "trove-classifiers (>=2024.10.12)"]
+typing = ["typing-extensions"]
+xmp = ["defusedxml"]
+
+[[package]]
+name = "pluggy"
+version = "1.5.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+files = [
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "prometheus-client"
+version = "0.21.1"
+description = "Python client for the Prometheus monitoring system."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
+ {file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
+]
+
+[package.extras]
+twisted = ["twisted"]
+
+[[package]]
+name = "protobuf"
+version = "5.29.4"
+description = ""
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "dev"]
+files = [
+ {file = "protobuf-5.29.4-cp310-abi3-win32.whl", hash = "sha256:13eb236f8eb9ec34e63fc8b1d6efd2777d062fa6aaa68268fb67cf77f6839ad7"},
+ {file = "protobuf-5.29.4-cp310-abi3-win_amd64.whl", hash = "sha256:bcefcdf3976233f8a502d265eb65ea740c989bacc6c30a58290ed0e519eb4b8d"},
+ {file = "protobuf-5.29.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:307ecba1d852ec237e9ba668e087326a67564ef83e45a0189a772ede9e854dd0"},
+ {file = "protobuf-5.29.4-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:aec4962f9ea93c431d5714ed1be1c93f13e1a8618e70035ba2b0564d9e633f2e"},
+ {file = "protobuf-5.29.4-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:d7d3f7d1d5a66ed4942d4fefb12ac4b14a29028b209d4bfb25c68ae172059922"},
+ {file = "protobuf-5.29.4-cp38-cp38-win32.whl", hash = "sha256:1832f0515b62d12d8e6ffc078d7e9eb06969aa6dc13c13e1036e39d73bebc2de"},
+ {file = "protobuf-5.29.4-cp38-cp38-win_amd64.whl", hash = "sha256:476cb7b14914c780605a8cf62e38c2a85f8caff2e28a6a0bad827ec7d6c85d68"},
+ {file = "protobuf-5.29.4-cp39-cp39-win32.whl", hash = "sha256:fd32223020cb25a2cc100366f1dedc904e2d71d9322403224cdde5fdced0dabe"},
+ {file = "protobuf-5.29.4-cp39-cp39-win_amd64.whl", hash = "sha256:678974e1e3a9b975b8bc2447fca458db5f93a2fb6b0c8db46b6675b5b5346812"},
+ {file = "protobuf-5.29.4-py3-none-any.whl", hash = "sha256:3fde11b505e1597f71b875ef2fc52062b6a9740e5f7c8997ce878b6009145862"},
+ {file = "protobuf-5.29.4.tar.gz", hash = "sha256:4f1dfcd7997b31ef8f53ec82781ff434a28bf71d9102ddde14d076adcfc78c99"},
+]
+
+[[package]]
+name = "psutil"
+version = "7.0.0"
+description = "Cross-platform lib for process and system monitoring in Python. NOTE: the syntax of this script MUST be kept compatible with Python 2.7."
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25"},
+ {file = "psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34"},
+ {file = "psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993"},
+ {file = "psutil-7.0.0-cp36-cp36m-win32.whl", hash = "sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17"},
+ {file = "psutil-7.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e"},
+ {file = "psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99"},
+ {file = "psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553"},
+ {file = "psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456"},
+]
+
+[package.extras]
+dev = ["abi3audit", "black (==24.10.0)", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest", "pytest-cov", "pytest-xdist", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "vulture", "wheel"]
+test = ["pytest", "pytest-xdist", "setuptools"]
+
+[[package]]
+name = "py-cpuinfo"
+version = "9.0.0"
+description = "Get CPU info with pure Python"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"},
+ {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"},
+]
+
+[[package]]
+name = "pydantic"
+version = "2.11.3"
+description = "Data validation using Python type hints"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"},
+ {file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"},
+]
+
+[package.dependencies]
+annotated-types = ">=0.6.0"
+pydantic-core = "2.33.1"
+typing-extensions = ">=4.12.2"
+typing-inspection = ">=0.4.0"
+
+[package.extras]
+email = ["email-validator (>=2.0.0)"]
+timezone = ["tzdata"]
+
+[[package]]
+name = "pydantic-core"
+version = "2.33.1"
+description = "Core functionality for Pydantic validation and serialization"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3077cfdb6125cc8dab61b155fdd714663e401f0e6883f9632118ec12cf42df26"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ffab8b2908d152e74862d276cf5017c81a2f3719f14e8e3e8d6b83fda863927"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5183e4f6a2d468787243ebcd70cf4098c247e60d73fb7d68d5bc1e1beaa0c4db"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:398a38d323f37714023be1e0285765f0a27243a8b1506b7b7de87b647b517e48"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87d3776f0001b43acebfa86f8c64019c043b55cc5a6a2e313d728b5c95b46969"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c566dd9c5f63d22226409553531f89de0cac55397f2ab8d97d6f06cfce6d947e"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0d5f3acc81452c56895e90643a625302bd6be351e7010664151cc55b7b97f89"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3a07fadec2a13274a8d861d3d37c61e97a816beae717efccaa4b36dfcaadcde"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f99aeda58dce827f76963ee87a0ebe75e648c72ff9ba1174a253f6744f518f65"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:902dbc832141aa0ec374f4310f1e4e7febeebc3256f00dc359a9ac3f264a45dc"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe44d56aa0b00d66640aa84a3cbe80b7a3ccdc6f0b1ca71090696a6d4777c091"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win32.whl", hash = "sha256:ed3eb16d51257c763539bde21e011092f127a2202692afaeaccb50db55a31383"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win_amd64.whl", hash = "sha256:694ad99a7f6718c1a498dc170ca430687a39894a60327f548e02a9c7ee4b6504"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e966fc3caaf9f1d96b349b0341c70c8d6573bf1bac7261f7b0ba88f96c56c24"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bfd0adeee563d59c598ceabddf2c92eec77abcb3f4a391b19aa7366170bd9e30"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91815221101ad3c6b507804178a7bb5cb7b2ead9ecd600041669c8d805ebd595"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9fea9c1869bb4742d174a57b4700c6dadea951df8b06de40c2fedb4f02931c2e"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d20eb4861329bb2484c021b9d9a977566ab16d84000a57e28061151c62b349a"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb935c5591573ae3201640579f30128ccc10739b45663f93c06796854405505"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c964fd24e6166420d18fb53996d8c9fd6eac9bf5ae3ec3d03015be4414ce497f"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:681d65e9011f7392db5aa002b7423cc442d6a673c635668c227c6c8d0e5a4f77"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e100c52f7355a48413e2999bfb4e139d2977a904495441b374f3d4fb4a170961"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:048831bd363490be79acdd3232f74a0e9951b11b2b4cc058aeb72b22fdc3abe1"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bdc84017d28459c00db6f918a7272a5190bec3090058334e43a76afb279eac7c"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win32.whl", hash = "sha256:32cd11c5914d1179df70406427097c7dcde19fddf1418c787540f4b730289896"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_amd64.whl", hash = "sha256:2ea62419ba8c397e7da28a9170a16219d310d2cf4970dbc65c32faf20d828c83"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_arm64.whl", hash = "sha256:fc903512177361e868bc1f5b80ac8c8a6e05fcdd574a5fb5ffeac5a9982b9e89"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1293d7febb995e9d3ec3ea09caf1a26214eec45b0f29f6074abb004723fc1de8"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:99b56acd433386c8f20be5c4000786d1e7ca0523c8eefc995d14d79c7a081498"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35a5ec3fa8c2fe6c53e1b2ccc2454398f95d5393ab398478f53e1afbbeb4d939"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b172f7b9d2f3abc0efd12e3386f7e48b576ef309544ac3a63e5e9cdd2e24585d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9097b9f17f91eea659b9ec58148c0747ec354a42f7389b9d50701610d86f812e"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc77ec5b7e2118b152b0d886c7514a4653bcb58c6b1d760134a9fab915f777b3"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3d15245b08fa4a84cefc6c9222e6f37c98111c8679fbd94aa145f9a0ae23d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef99779001d7ac2e2461d8ab55d3373fe7315caefdbecd8ced75304ae5a6fc6b"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fc6bf8869e193855e8d91d91f6bf59699a5cdfaa47a404e278e776dd7f168b39"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:b1caa0bc2741b043db7823843e1bde8aaa58a55a58fda06083b0569f8b45693a"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ec259f62538e8bf364903a7d0d0239447059f9434b284f5536e8402b7dd198db"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win32.whl", hash = "sha256:e14f369c98a7c15772b9da98987f58e2b509a93235582838bd0d1d8c08b68fda"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_amd64.whl", hash = "sha256:1c607801d85e2e123357b3893f82c97a42856192997b95b4d8325deb1cd0c5f4"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_arm64.whl", hash = "sha256:8d13f0276806ee722e70a1c93da19748594f19ac4299c7e41237fc791d1861ea"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:70af6a21237b53d1fe7b9325b20e65cbf2f0a848cf77bed492b029139701e66a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:282b3fe1bbbe5ae35224a0dbd05aed9ccabccd241e8e6b60370484234b456266"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b315e596282bbb5822d0c7ee9d255595bd7506d1cb20c2911a4da0b970187d3"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dfae24cf9921875ca0ca6a8ecb4bb2f13c855794ed0d468d6abbec6e6dcd44a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd8ecfde08d8bfadaea669e83c63939af76f4cf5538a72597016edfa3fad516"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f593494876eae852dc98c43c6f260f45abdbfeec9e4324e31a481d948214764"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948b73114f47fd7016088e5186d13faf5e1b2fe83f5e320e371f035557fd264d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e11f3864eb516af21b01e25fac915a82e9ddad3bb0fb9e95a246067398b435a4"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:549150be302428b56fdad0c23c2741dcdb5572413776826c965619a25d9c6bde"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:495bc156026efafd9ef2d82372bd38afce78ddd82bf28ef5276c469e57c0c83e"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ec79de2a8680b1a67a07490bddf9636d5c2fab609ba8c57597e855fa5fa4dacd"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win32.whl", hash = "sha256:ee12a7be1742f81b8a65b36c6921022301d466b82d80315d215c4c691724986f"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_amd64.whl", hash = "sha256:ede9b407e39949d2afc46385ce6bd6e11588660c26f80576c11c958e6647bc40"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_arm64.whl", hash = "sha256:aa687a23d4b7871a00e03ca96a09cad0f28f443690d300500603bd0adba4b523"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:401d7b76e1000d0dd5538e6381d28febdcacb097c8d340dde7d7fc6e13e9f95d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7aeb055a42d734c0255c9e489ac67e75397d59c6fbe60d155851e9782f276a9c"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-win_amd64.whl", hash = "sha256:338ea9b73e6e109f15ab439e62cb3b78aa752c7fd9536794112e14bee02c8d18"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5ab77f45d33d264de66e1884fca158bc920cb5e27fd0764a72f72f5756ae8bdb"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7aaba1b4b03aaea7bb59e1b5856d734be011d3e6d98f5bcaa98cb30f375f2ad"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fb66263e9ba8fea2aa85e1e5578980d127fb37d7f2e292773e7bc3a38fb0c7b"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f2648b9262607a7fb41d782cc263b48032ff7a03a835581abbf7a3bec62bcf5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:723c5630c4259400818b4ad096735a829074601805d07f8cafc366d95786d331"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d100e3ae783d2167782391e0c1c7a20a31f55f8015f3293647544df3f9c67824"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177d50460bc976a0369920b6c744d927b0ecb8606fb56858ff542560251b19e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3edde68d1a1f9af1273b2fe798997b33f90308fb6d44d8550c89fc6a3647cf6"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a62c3c3ef6a7e2c45f7853b10b5bc4ddefd6ee3cd31024754a1a5842da7d598d"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:c91dbb0ab683fa0cd64a6e81907c8ff41d6497c346890e26b23de7ee55353f96"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f466e8bf0a62dc43e068c12166281c2eca72121dd2adc1040f3aa1e21ef8599"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win32.whl", hash = "sha256:ab0277cedb698749caada82e5d099dc9fed3f906a30d4c382d1a21725777a1e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win_amd64.whl", hash = "sha256:5773da0ee2d17136b1f1c6fbde543398d452a6ad2a7b54ea1033e2daa739b8d2"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c834f54f8f4640fd7e4b193f80eb25a0602bba9e19b3cd2fc7ffe8199f5ae02"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:049e0de24cf23766f12cc5cc71d8abc07d4a9deb9061b334b62093dedc7cb068"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a28239037b3d6f16916a4c831a5a0eadf856bdd6d2e92c10a0da3a59eadcf3e"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d3da303ab5f378a268fa7d45f37d7d85c3ec19769f28d2cc0c61826a8de21fe"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:25626fb37b3c543818c14821afe0fd3830bc327a43953bc88db924b68c5723f1"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3ab2d36e20fbfcce8f02d73c33a8a7362980cff717926bbae030b93ae46b56c7"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:2f9284e11c751b003fd4215ad92d325d92c9cb19ee6729ebd87e3250072cdcde"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:048c01eee07d37cbd066fc512b9d8b5ea88ceeb4e629ab94b3e56965ad655add"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5ccd429694cf26af7997595d627dd2637e7932214486f55b8a357edaac9dae8c"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3a371dc00282c4b84246509a5ddc808e61b9864aa1eae9ecc92bb1268b82db4a"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:f59295ecc75a1788af8ba92f2e8c6eeaa5a94c22fc4d151e8d9638814f85c8fc"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08530b8ac922003033f399128505f513e30ca770527cc8bbacf75a84fcc2c74b"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae370459da6a5466978c0eacf90690cb57ec9d533f8e63e564ef3822bfa04fe"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3de2777e3b9f4d603112f78006f4ae0acb936e95f06da6cb1a45fbad6bdb4b5"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a64e81e8cba118e108d7126362ea30e021291b7805d47e4896e52c791be2761"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:52928d8c1b6bda03cc6d811e8923dffc87a2d3c8b3bfd2ce16471c7147a24850"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1b30d92c9412beb5ac6b10a3eb7ef92ccb14e3f2a8d7732e2d739f58b3aa7544"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f995719707e0e29f0f41a8aa3bcea6e761a36c9136104d3189eafb83f5cec5e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7edbc454a29fc6aeae1e1eecba4f07b63b8d76e76a748532233c4c167b4cb9ea"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:ad05b683963f69a1d5d2c2bdab1274a31221ca737dbbceaa32bcb67359453cdd"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df6a94bf9452c6da9b5d76ed229a5683d0306ccb91cca8e1eea883189780d568"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7965c13b3967909a09ecc91f21d09cfc4576bf78140b988904e94f130f188396"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3f1fdb790440a34f6ecf7679e1863b825cb5ffde858a9197f851168ed08371e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:5277aec8d879f8d05168fdd17ae811dd313b8ff894aeeaf7cd34ad28b4d77e33"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8ab581d3530611897d863d1a649fb0644b860286b4718db919bfd51ece41f10b"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0483847fa9ad5e3412265c1bd72aad35235512d9ce9d27d81a56d935ef489672"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:de9e06abe3cc5ec6a2d5f75bc99b0bdca4f5c719a5b34026f8c57efbdecd2ee3"},
+ {file = "pydantic_core-2.33.1.tar.gz", hash = "sha256:bcc9c6fdb0ced789245b02b7d6603e17d1563064ddcfc36f046b61c0c05dd9df"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
+
+[[package]]
+name = "pygments"
+version = "2.19.1"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"},
+ {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"},
+]
+
+[package.extras]
+windows-terminal = ["colorama (>=0.4.6)"]
+
+[[package]]
+name = "pytest"
+version = "8.3.5"
+description = "pytest: simple powerful testing with Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+files = [
+ {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"},
+ {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=1.5,<2"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
+
+[package.extras]
+dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pyyaml"
+version = "6.0.2"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"},
+ {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"},
+ {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"},
+ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
+]
+
+[[package]]
+name = "referencing"
+version = "0.36.2"
+description = "JSON Referencing + Python"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"},
+ {file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+rpds-py = ">=0.7.0"
+typing-extensions = {version = ">=4.4.0", markers = "python_version < \"3.13\""}
+
+[[package]]
+name = "regex"
+version = "2024.11.6"
+description = "Alternative regular expression module, to replace re."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"},
+ {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"},
+ {file = "regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c"},
+ {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008"},
+ {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62"},
+ {file = "regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e"},
+ {file = "regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7"},
+ {file = "regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0"},
+ {file = "regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d"},
+ {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45"},
+ {file = "regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9"},
+ {file = "regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9"},
+ {file = "regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e"},
+ {file = "regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51"},
+ {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad"},
+ {file = "regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54"},
+ {file = "regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4"},
+ {file = "regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c"},
+ {file = "regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4"},
+ {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d"},
+ {file = "regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff"},
+ {file = "regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3"},
+ {file = "regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f"},
+ {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc"},
+ {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f"},
+ {file = "regex-2024.11.6-cp38-cp38-win32.whl", hash = "sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4"},
+ {file = "regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e"},
+ {file = "regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48"},
+ {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f"},
+ {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b"},
+ {file = "regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57"},
+ {file = "regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983"},
+ {file = "regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519"},
+]
+
+[[package]]
+name = "requests"
+version = "2.32.3"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+ {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
+[[package]]
+name = "rich"
+version = "14.0.0"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+ {file = "rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0"},
+ {file = "rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
+[[package]]
+name = "rpds-py"
+version = "0.24.0"
+description = "Python bindings to Rust's persistent data structures (rpds)"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "rpds_py-0.24.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:006f4342fe729a368c6df36578d7a348c7c716be1da0a1a0f86e3021f8e98724"},
+ {file = "rpds_py-0.24.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2d53747da70a4e4b17f559569d5f9506420966083a31c5fbd84e764461c4444b"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8acd55bd5b071156bae57b555f5d33697998752673b9de554dd82f5b5352727"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7e80d375134ddb04231a53800503752093dbb65dad8dabacce2c84cccc78e964"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60748789e028d2a46fc1c70750454f83c6bdd0d05db50f5ae83e2db500b34da5"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e1daf5bf6c2be39654beae83ee6b9a12347cb5aced9a29eecf12a2d25fff664"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b221c2457d92a1fb3c97bee9095c874144d196f47c038462ae6e4a14436f7bc"},
+ {file = "rpds_py-0.24.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:66420986c9afff67ef0c5d1e4cdc2d0e5262f53ad11e4f90e5e22448df485bf0"},
+ {file = "rpds_py-0.24.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:43dba99f00f1d37b2a0265a259592d05fcc8e7c19d140fe51c6e6f16faabeb1f"},
+ {file = "rpds_py-0.24.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a88c0d17d039333a41d9bf4616bd062f0bd7aa0edeb6cafe00a2fc2a804e944f"},
+ {file = "rpds_py-0.24.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc31e13ce212e14a539d430428cd365e74f8b2d534f8bc22dd4c9c55b277b875"},
+ {file = "rpds_py-0.24.0-cp310-cp310-win32.whl", hash = "sha256:fc2c1e1b00f88317d9de6b2c2b39b012ebbfe35fe5e7bef980fd2a91f6100a07"},
+ {file = "rpds_py-0.24.0-cp310-cp310-win_amd64.whl", hash = "sha256:c0145295ca415668420ad142ee42189f78d27af806fcf1f32a18e51d47dd2052"},
+ {file = "rpds_py-0.24.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2d3ee4615df36ab8eb16c2507b11e764dcc11fd350bbf4da16d09cda11fcedef"},
+ {file = "rpds_py-0.24.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e13ae74a8a3a0c2f22f450f773e35f893484fcfacb00bb4344a7e0f4f48e1f97"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf86f72d705fc2ef776bb7dd9e5fbba79d7e1f3e258bf9377f8204ad0fc1c51e"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c43583ea8517ed2e780a345dd9960896afc1327e8cf3ac8239c167530397440d"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cd031e63bc5f05bdcda120646a0d32f6d729486d0067f09d79c8db5368f4586"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:34d90ad8c045df9a4259c47d2e16a3f21fdb396665c94520dbfe8766e62187a4"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e838bf2bb0b91ee67bf2b889a1a841e5ecac06dd7a2b1ef4e6151e2ce155c7ae"},
+ {file = "rpds_py-0.24.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04ecf5c1ff4d589987b4d9882872f80ba13da7d42427234fce8f22efb43133bc"},
+ {file = "rpds_py-0.24.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:630d3d8ea77eabd6cbcd2ea712e1c5cecb5b558d39547ac988351195db433f6c"},
+ {file = "rpds_py-0.24.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ebcb786b9ff30b994d5969213a8430cbb984cdd7ea9fd6df06663194bd3c450c"},
+ {file = "rpds_py-0.24.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:174e46569968ddbbeb8a806d9922f17cd2b524aa753b468f35b97ff9c19cb718"},
+ {file = "rpds_py-0.24.0-cp311-cp311-win32.whl", hash = "sha256:5ef877fa3bbfb40b388a5ae1cb00636a624690dcb9a29a65267054c9ea86d88a"},
+ {file = "rpds_py-0.24.0-cp311-cp311-win_amd64.whl", hash = "sha256:e274f62cbd274359eff63e5c7e7274c913e8e09620f6a57aae66744b3df046d6"},
+ {file = "rpds_py-0.24.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:d8551e733626afec514b5d15befabea0dd70a343a9f23322860c4f16a9430205"},
+ {file = "rpds_py-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e374c0ce0ca82e5b67cd61fb964077d40ec177dd2c4eda67dba130de09085c7"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d69d003296df4840bd445a5d15fa5b6ff6ac40496f956a221c4d1f6f7b4bc4d9"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8212ff58ac6dfde49946bea57474a386cca3f7706fc72c25b772b9ca4af6b79e"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:528927e63a70b4d5f3f5ccc1fa988a35456eb5d15f804d276709c33fc2f19bda"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a824d2c7a703ba6daaca848f9c3d5cb93af0505be505de70e7e66829affd676e"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d51febb7a114293ffd56c6cf4736cb31cd68c0fddd6aa303ed09ea5a48e029"},
+ {file = "rpds_py-0.24.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3fab5f4a2c64a8fb64fc13b3d139848817a64d467dd6ed60dcdd6b479e7febc9"},
+ {file = "rpds_py-0.24.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9be4f99bee42ac107870c61dfdb294d912bf81c3c6d45538aad7aecab468b6b7"},
+ {file = "rpds_py-0.24.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:564c96b6076a98215af52f55efa90d8419cc2ef45d99e314fddefe816bc24f91"},
+ {file = "rpds_py-0.24.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:75a810b7664c17f24bf2ffd7f92416c00ec84b49bb68e6a0d93e542406336b56"},
+ {file = "rpds_py-0.24.0-cp312-cp312-win32.whl", hash = "sha256:f6016bd950be4dcd047b7475fdf55fb1e1f59fc7403f387be0e8123e4a576d30"},
+ {file = "rpds_py-0.24.0-cp312-cp312-win_amd64.whl", hash = "sha256:998c01b8e71cf051c28f5d6f1187abbdf5cf45fc0efce5da6c06447cba997034"},
+ {file = "rpds_py-0.24.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:3d2d8e4508e15fc05b31285c4b00ddf2e0eb94259c2dc896771966a163122a0c"},
+ {file = "rpds_py-0.24.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0f00c16e089282ad68a3820fd0c831c35d3194b7cdc31d6e469511d9bffc535c"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:951cc481c0c395c4a08639a469d53b7d4afa252529a085418b82a6b43c45c240"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c9ca89938dff18828a328af41ffdf3902405a19f4131c88e22e776a8e228c5a8"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed0ef550042a8dbcd657dfb284a8ee00f0ba269d3f2286b0493b15a5694f9fe8"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b2356688e5d958c4d5cb964af865bea84db29971d3e563fb78e46e20fe1848b"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78884d155fd15d9f64f5d6124b486f3d3f7fd7cd71a78e9670a0f6f6ca06fb2d"},
+ {file = "rpds_py-0.24.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6a4a535013aeeef13c5532f802708cecae8d66c282babb5cd916379b72110cf7"},
+ {file = "rpds_py-0.24.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:84e0566f15cf4d769dade9b366b7b87c959be472c92dffb70462dd0844d7cbad"},
+ {file = "rpds_py-0.24.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:823e74ab6fbaa028ec89615ff6acb409e90ff45580c45920d4dfdddb069f2120"},
+ {file = "rpds_py-0.24.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c61a2cb0085c8783906b2f8b1f16a7e65777823c7f4d0a6aaffe26dc0d358dd9"},
+ {file = "rpds_py-0.24.0-cp313-cp313-win32.whl", hash = "sha256:60d9b630c8025b9458a9d114e3af579a2c54bd32df601c4581bd054e85258143"},
+ {file = "rpds_py-0.24.0-cp313-cp313-win_amd64.whl", hash = "sha256:6eea559077d29486c68218178ea946263b87f1c41ae7f996b1f30a983c476a5a"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:d09dc82af2d3c17e7dd17120b202a79b578d79f2b5424bda209d9966efeed114"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5fc13b44de6419d1e7a7e592a4885b323fbc2f46e1f22151e3a8ed3b8b920405"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c347a20d79cedc0a7bd51c4d4b7dbc613ca4e65a756b5c3e57ec84bd43505b47"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20f2712bd1cc26a3cc16c5a1bfee9ed1abc33d4cdf1aabd297fe0eb724df4272"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aad911555286884be1e427ef0dc0ba3929e6821cbeca2194b13dc415a462c7fd"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0aeb3329c1721c43c58cae274d7d2ca85c1690d89485d9c63a006cb79a85771a"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a0f156e9509cee987283abd2296ec816225145a13ed0391df8f71bf1d789e2d"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aa6800adc8204ce898c8a424303969b7aa6a5e4ad2789c13f8648739830323b7"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a18fc371e900a21d7392517c6f60fe859e802547309e94313cd8181ad9db004d"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:9168764133fd919f8dcca2ead66de0105f4ef5659cbb4fa044f7014bed9a1797"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5f6e3cec44ba05ee5cbdebe92d052f69b63ae792e7d05f1020ac5e964394080c"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-win32.whl", hash = "sha256:8ebc7e65ca4b111d928b669713865f021b7773350eeac4a31d3e70144297baba"},
+ {file = "rpds_py-0.24.0-cp313-cp313t-win_amd64.whl", hash = "sha256:675269d407a257b8c00a6b58205b72eec8231656506c56fd429d924ca00bb350"},
+ {file = "rpds_py-0.24.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a36b452abbf29f68527cf52e181fced56685731c86b52e852053e38d8b60bc8d"},
+ {file = "rpds_py-0.24.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b3b397eefecec8e8e39fa65c630ef70a24b09141a6f9fc17b3c3a50bed6b50e"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdabcd3beb2a6dca7027007473d8ef1c3b053347c76f685f5f060a00327b8b65"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5db385bacd0c43f24be92b60c857cf760b7f10d8234f4bd4be67b5b20a7c0b6b"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8097b3422d020ff1c44effc40ae58e67d93e60d540a65649d2cdaf9466030791"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:493fe54318bed7d124ce272fc36adbf59d46729659b2c792e87c3b95649cdee9"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8aa362811ccdc1f8dadcc916c6d47e554169ab79559319ae9fae7d7752d0d60c"},
+ {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d8f9a6e7fd5434817526815f09ea27f2746c4a51ee11bb3439065f5fc754db58"},
+ {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8205ee14463248d3349131bb8099efe15cd3ce83b8ef3ace63c7e976998e7124"},
+ {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:921ae54f9ecba3b6325df425cf72c074cd469dea843fb5743a26ca7fb2ccb149"},
+ {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32bab0a56eac685828e00cc2f5d1200c548f8bc11f2e44abf311d6b548ce2e45"},
+ {file = "rpds_py-0.24.0-cp39-cp39-win32.whl", hash = "sha256:f5c0ed12926dec1dfe7d645333ea59cf93f4d07750986a586f511c0bc61fe103"},
+ {file = "rpds_py-0.24.0-cp39-cp39-win_amd64.whl", hash = "sha256:afc6e35f344490faa8276b5f2f7cbf71f88bc2cda4328e00553bd451728c571f"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:619ca56a5468f933d940e1bf431c6f4e13bef8e688698b067ae68eb4f9b30e3a"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b28e5122829181de1898c2c97f81c0b3246d49f585f22743a1246420bb8d399"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e5ab32cf9eb3647450bc74eb201b27c185d3857276162c101c0f8c6374e098"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:208b3a70a98cf3710e97cabdc308a51cd4f28aa6e7bb11de3d56cd8b74bab98d"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbc4362e06f950c62cad3d4abf1191021b2ffaf0b31ac230fbf0526453eee75e"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ebea2821cdb5f9fef44933617be76185b80150632736f3d76e54829ab4a3b4d1"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9a4df06c35465ef4d81799999bba810c68d29972bf1c31db61bfdb81dd9d5bb"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3aa13bdf38630da298f2e0d77aca967b200b8cc1473ea05248f6c5e9c9bdb44"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:041f00419e1da7a03c46042453598479f45be3d787eb837af382bfc169c0db33"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:d8754d872a5dfc3c5bf9c0e059e8107451364a30d9fd50f1f1a85c4fb9481164"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:896c41007931217a343eff197c34513c154267636c8056fb409eafd494c3dcdc"},
+ {file = "rpds_py-0.24.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:92558d37d872e808944c3c96d0423b8604879a3d1c86fdad508d7ed91ea547d5"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f9e0057a509e096e47c87f753136c9b10d7a91842d8042c2ee6866899a717c0d"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d6e109a454412ab82979c5b1b3aee0604eca4bbf9a02693bb9df027af2bfa91a"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc1c892b1ec1f8cbd5da8de287577b455e388d9c328ad592eabbdcb6fc93bee5"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9c39438c55983d48f4bb3487734d040e22dad200dab22c41e331cee145e7a50d"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d7e8ce990ae17dda686f7e82fd41a055c668e13ddcf058e7fb5e9da20b57793"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9ea7f4174d2e4194289cb0c4e172d83e79a6404297ff95f2875cf9ac9bced8ba"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb2954155bb8f63bb19d56d80e5e5320b61d71084617ed89efedb861a684baea"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04f2b712a2206e13800a8136b07aaedc23af3facab84918e7aa89e4be0260032"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:eda5c1e2a715a4cbbca2d6d304988460942551e4e5e3b7457b50943cd741626d"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:9abc80fe8c1f87218db116016de575a7998ab1629078c90840e8d11ab423ee25"},
+ {file = "rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6a727fd083009bc83eb83d6950f0c32b3c94c8b80a9b667c87f4bd1274ca30ba"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e0f3ef95795efcd3b2ec3fe0a5bcfb5dadf5e3996ea2117427e524d4fbf309c6"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2c13777ecdbbba2077670285dd1fe50828c8742f6a4119dbef6f83ea13ad10fb"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79e8d804c2ccd618417e96720ad5cd076a86fa3f8cb310ea386a3e6229bae7d1"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd822f019ccccd75c832deb7aa040bb02d70a92eb15a2f16c7987b7ad4ee8d83"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0047638c3aa0dbcd0ab99ed1e549bbf0e142c9ecc173b6492868432d8989a046"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a5b66d1b201cc71bc3081bc2f1fc36b0c1f268b773e03bbc39066651b9e18391"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbcbb6db5582ea33ce46a5d20a5793134b5365110d84df4e30b9d37c6fd40ad3"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63981feca3f110ed132fd217bf7768ee8ed738a55549883628ee3da75bb9cb78"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:3a55fc10fdcbf1a4bd3c018eea422c52cf08700cf99c28b5cb10fe97ab77a0d3"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:c30ff468163a48535ee7e9bf21bd14c7a81147c0e58a36c1078289a8ca7af0bd"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:369d9c6d4c714e36d4a03957b4783217a3ccd1e222cdd67d464a3a479fc17796"},
+ {file = "rpds_py-0.24.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:24795c099453e3721fda5d8ddd45f5dfcc8e5a547ce7b8e9da06fecc3832e26f"},
+ {file = "rpds_py-0.24.0.tar.gz", hash = "sha256:772cc1b2cd963e7e17e6cc55fe0371fb9c704d63e44cacec7b9b7f523b78919e"},
+]
+
+[[package]]
+name = "safetensors"
+version = "0.5.3"
+description = ""
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073"},
+ {file = "safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a"},
+ {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135"},
+ {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04"},
+ {file = "safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace"},
+ {file = "safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11"},
+ {file = "safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965"},
+]
+
+[package.extras]
+all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"]
+dev = ["safetensors[all]"]
+jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"]
+mlx = ["mlx (>=0.0.9)"]
+numpy = ["numpy (>=1.21.6)"]
+paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"]
+pinned-tf = ["safetensors[numpy]", "tensorflow (==2.18.0)"]
+quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"]
+tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
+testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
+torch = ["safetensors[numpy]", "torch (>=1.10)"]
+
+[[package]]
+name = "scikit-learn"
+version = "1.6.1"
+description = "A set of python modules for machine learning and data mining"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b"},
+ {file = "scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8"},
+ {file = "scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86"},
+ {file = "scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ffa1e9e25b3d93990e74a4be2c2fc61ee5af85811562f1288d5d055880c4322"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dc5cf3d68c5a20ad6d571584c0750ec641cc46aeef1c1507be51300e6003a7e1"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06beb2e839ecc641366000ca84f3cf6fa9faa1777e29cf0c04be6e4d096a348"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8ca8cb270fee8f1f76fa9bfd5c3507d60c6438bbee5687f81042e2bb98e5a97"},
+ {file = "scikit_learn-1.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:7a1c43c8ec9fde528d664d947dc4c0789be4077a3647f232869f41d9bf50e0fb"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a17c1dea1d56dcda2fac315712f3651a1fea86565b64b48fa1bc090249cbf236"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6a7aa5f9908f0f28f4edaa6963c0a6183f1911e63a69aa03782f0d924c830a35"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0650e730afb87402baa88afbf31c07b84c98272622aaba002559b614600ca691"},
+ {file = "scikit_learn-1.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:3f59fe08dc03ea158605170eb52b22a105f238a5d512c4470ddeca71feae8e5f"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6849dd3234e87f55dce1db34c89a810b489ead832aaf4d4550b7ea85628be6c1"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e7be3fa5d2eb9be7d77c3734ff1d599151bb523674be9b834e8da6abe132f44e"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44a17798172df1d3c1065e8fcf9019183f06c87609b49a124ebdf57ae6cb0107"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8b7a3b86e411e4bce21186e1c180d792f3d99223dcfa3b4f597ecc92fa1a422"},
+ {file = "scikit_learn-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:7a73d457070e3318e32bdb3aa79a8d990474f19035464dfd8bede2883ab5dc3b"},
+ {file = "scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e"},
+]
+
+[package.dependencies]
+joblib = ">=1.2.0"
+numpy = ">=1.19.5"
+scipy = ">=1.6.0"
+threadpoolctl = ">=3.1.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
+build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.17.1)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)", "towncrier (>=24.8.0)"]
+examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
+install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
+maintenance = ["conda-lock (==2.5.6)"]
+tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.5.1)", "scikit-image (>=0.17.2)"]
+
+[[package]]
+name = "scipy"
+version = "1.13.1"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"},
+ {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"},
+ {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"},
+ {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"},
+ {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"},
+ {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"},
+ {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"},
+ {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"},
+ {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"},
+ {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"},
+ {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"},
+ {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"},
+ {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"},
+ {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"},
+ {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"},
+ {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"},
+ {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"},
+ {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"},
+ {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"},
+ {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"},
+ {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"},
+ {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"},
+ {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"},
+ {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"},
+ {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"},
+]
+
+[package.dependencies]
+numpy = ">=1.22.4,<2.3"
+
+[package.extras]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
+test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
+[[package]]
+name = "sentence-transformers"
+version = "3.3.1"
+description = "State-of-the-Art Text Embeddings"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "sentence_transformers-3.3.1-py3-none-any.whl", hash = "sha256:abffcc79dab37b7d18d21a26d5914223dd42239cfe18cb5e111c66c54b658ae7"},
+ {file = "sentence_transformers-3.3.1.tar.gz", hash = "sha256:9635dbfb11c6b01d036b9cfcee29f7716ab64cf2407ad9f403a2e607da2ac48b"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.20.0"
+Pillow = "*"
+scikit-learn = "*"
+scipy = "*"
+torch = ">=1.11.0"
+tqdm = "*"
+transformers = ">=4.41.0,<5.0.0"
+
+[package.extras]
+dev = ["accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"]
+onnx = ["optimum[onnxruntime] (>=1.23.1)"]
+onnx-gpu = ["optimum[onnxruntime-gpu] (>=1.23.1)"]
+openvino = ["optimum-intel[openvino] (>=1.20.0)"]
+train = ["accelerate (>=0.20.3)", "datasets"]
+
+[[package]]
+name = "sentencepiece"
+version = "0.2.0"
+description = "SentencePiece python wrapper"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"},
+ {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"},
+ {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"},
+ {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"},
+ {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"},
+ {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"},
+ {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"},
+ {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"},
+ {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"},
+]
+
+[[package]]
+name = "setuptools"
+version = "78.1.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.9"
+groups = ["main", "dev"]
+files = [
+ {file = "setuptools-78.1.0-py3-none-any.whl", hash = "sha256:3e386e96793c8702ae83d17b853fb93d3e09ef82ec62722e61da5cd22376dcd8"},
+ {file = "setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54"},
+]
+markers = {main = "python_version >= \"3.12\""}
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
+core = ["importlib_metadata (>=6)", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
+
+[[package]]
+name = "shellingham"
+version = "1.5.4"
+description = "Tool to Detect Surrounding Shell"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
+ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
+]
+
+[[package]]
+name = "sympy"
+version = "1.13.1"
+description = "Computer algebra system (CAS) in Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"},
+ {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"},
+]
+
+[package.dependencies]
+mpmath = ">=1.1.0,<1.4"
+
+[package.extras]
+dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
+
+[[package]]
+name = "threadpoolctl"
+version = "3.6.0"
+description = "threadpoolctl"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb"},
+ {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"},
+]
+
+[[package]]
+name = "tokenizers"
+version = "0.21.1"
+description = ""
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41"},
+ {file = "tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28da6b72d4fb14ee200a1bd386ff74ade8992d7f725f2bde2c495a9a98cf4d9f"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34d8cfde551c9916cb92014e040806122295a6800914bab5865deb85623931cf"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaa852d23e125b73d283c98f007e06d4595732104b65402f46e8ef24b588d9f8"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a21a15d5c8e603331b8a59548bbe113564136dc0f5ad8306dd5033459a226da0"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fdbd4c067c60a0ac7eca14b6bd18a5bebace54eb757c706b47ea93204f7a37c"},
+ {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dd9a0061e403546f7377df940e866c3e678d7d4e9643d0461ea442b4f89e61a"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:db9484aeb2e200c43b915a1a0150ea885e35f357a5a8fabf7373af333dcc8dbf"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed248ab5279e601a30a4d67bdb897ecbe955a50f1e7bb62bd99f07dd11c2f5b6"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9ac78b12e541d4ce67b4dfd970e44c060a2147b9b2a21f509566d556a509c67d"},
+ {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5a69c1a4496b81a5ee5d2c1f3f7fbdf95e90a0196101b0ee89ed9956b8a168f"},
+ {file = "tokenizers-0.21.1-cp39-abi3-win32.whl", hash = "sha256:1039a3a5734944e09de1d48761ade94e00d0fa760c0e0551151d4dd851ba63e3"},
+ {file = "tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382"},
+ {file = "tokenizers-0.21.1.tar.gz", hash = "sha256:a1bb04dc5b448985f86ecd4b05407f5a8d97cb2c0532199b2a302a604a0165ab"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.16.4,<1.0"
+
+[package.extras]
+dev = ["tokenizers[testing]"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
+
+[[package]]
+name = "tomli"
+version = "2.2.1"
+description = "A lil' TOML parser"
+optional = false
+python-versions = ">=3.8"
+groups = ["dev"]
+markers = "python_version < \"3.11\""
+files = [
+ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"},
+ {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"},
+ {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"},
+ {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"},
+ {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"},
+ {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"},
+ {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"},
+ {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"},
+ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
+]
+
+[[package]]
+name = "torch"
+version = "2.6.0"
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961"},
+ {file = "torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab"},
+ {file = "torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341"},
+ {file = "torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628"},
+ {file = "torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1"},
+ {file = "torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d"},
+ {file = "torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7"},
+ {file = "torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21"},
+ {file = "torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9"},
+ {file = "torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb"},
+ {file = "torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239"},
+ {file = "torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989"},
+ {file = "torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf"},
+ {file = "torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b"},
+ {file = "torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc"},
+ {file = "torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2"},
+ {file = "torch-2.6.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ea955317cfcd3852b1402b62af258ce735c2edeee42ca9419b6bc889e5ae053"},
+ {file = "torch-2.6.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bb2c6c3e65049f081940f5ab15c9136c7de40d3f01192541c920a07c7c585b7e"},
+ {file = "torch-2.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:683410f97984103148e31b38a8631acf31c3034c020c0f4d26171e7626d8317a"},
+ {file = "torch-2.6.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:265f70de5fd45b864d924b64be1797f86e76c8e48a02c2a3a6fc7ec247d2226c"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = "*"
+jinja2 = "*"
+networkx = "*"
+nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparselt-cu12 = {version = "0.6.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+setuptools = {version = "*", markers = "python_version >= \"3.12\""}
+sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""}
+triton = {version = "3.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+typing-extensions = ">=4.10.0"
+
+[package.extras]
+opt-einsum = ["opt-einsum (>=3.3)"]
+optree = ["optree (>=0.13.0)"]
+
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
+ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
+[[package]]
+name = "transformers"
+version = "4.49.0"
+description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"},
+ {file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"},
+]
+
+[package.dependencies]
+filelock = "*"
+huggingface-hub = ">=0.26.0,<1.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+pyyaml = ">=5.1"
+regex = "!=2019.12.17"
+requests = "*"
+safetensors = ">=0.4.1"
+tokenizers = ">=0.21,<0.22"
+tqdm = ">=4.27"
+
+[package.extras]
+accelerate = ["accelerate (>=0.26.0)"]
+agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"]
+all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
+audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+benchmark = ["optimum-benchmark (>=0.3.0)"]
+codecarbon = ["codecarbon (>=2.8.1)"]
+deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
+deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
+dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
+flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+ftfy = ["ftfy"]
+integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
+ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
+modelcreation = ["cookiecutter (==1.7.3)"]
+natten = ["natten (>=0.14.6,<0.15.0)"]
+onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
+onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
+optuna = ["optuna"]
+quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
+ray = ["ray[tune] (>=2.7.0)"]
+retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
+ruff = ["ruff (==0.5.1)"]
+sagemaker = ["sagemaker (>=2.31.0)"]
+sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
+serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
+sigopt = ["sigopt"]
+sklearn = ["scikit-learn"]
+speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
+tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+tiktoken = ["blobfile", "tiktoken"]
+timm = ["timm (<=1.0.11)"]
+tokenizers = ["tokenizers (>=0.21,<0.22)"]
+torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"]
+torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
+torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
+video = ["av"]
+vision = ["Pillow (>=10.0.1,<=15.0)"]
+
+[[package]]
+name = "triton"
+version = "3.2.0"
+description = "A language and compiler for custom Deep Learning operations"
+optional = false
+python-versions = "*"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62"},
+ {file = "triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220"},
+ {file = "triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c"},
+ {file = "triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0"},
+ {file = "triton-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30ceed0eff2c4a73b14eb63e052992f44bbdf175f3fad21e1ac8097a772de7ee"},
+]
+
+[package.extras]
+build = ["cmake (>=3.20)", "lit"]
+tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"]
+tutorials = ["matplotlib", "pandas", "tabulate"]
+
+[[package]]
+name = "typer"
+version = "0.15.2"
+description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "typer-0.15.2-py3-none-any.whl", hash = "sha256:46a499c6107d645a9c13f7ee46c5d5096cae6f5fc57dd11eccbbb9ae3e44ddfc"},
+ {file = "typer-0.15.2.tar.gz", hash = "sha256:ab2fab47533a813c49fe1f16b1a370fd5819099c00b119e0633df65f22144ba5"},
+]
+
+[package.dependencies]
+click = ">=8.0.0"
+rich = ">=10.11.0"
+shellingham = ">=1.3.0"
+typing-extensions = ">=3.7.4.3"
+
+[[package]]
+name = "typing-extensions"
+version = "4.13.2"
+description = "Backported and Experimental Type Hints for Python 3.8+"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"},
+ {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"},
+]
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.0"
+description = "Runtime typing introspection tools"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f"},
+ {file = "typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12.0"
+
+[[package]]
+name = "urllib3"
+version = "2.4.0"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813"},
+ {file = "urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466"},
+]
+
+[package.extras]
+brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+h2 = ["h2 (>=4,<5)"]
+socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
+zstd = ["zstandard (>=0.18.0)"]
+
+[[package]]
+name = "win32-setctime"
+version = "1.2.0"
+description = "A small Python utility to set file creation time on Windows"
+optional = false
+python-versions = ">=3.5"
+groups = ["main"]
+markers = "sys_platform == \"win32\""
+files = [
+ {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"},
+ {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"},
+]
+
+[package.extras]
+dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
+
+[[package]]
+name = "wrapt"
+version = "1.17.2"
+description = "Module for decorators, wrappers and monkey patching."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d57c572081fed831ad2d26fd430d565b76aa277ed1d30ff4d40670b1c0dd984"},
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5e251054542ae57ac7f3fba5d10bfff615b6c2fb09abeb37d2f1463f841ae22"},
+ {file = "wrapt-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80dd7db6a7cb57ffbc279c4394246414ec99537ae81ffd702443335a61dbf3a7"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a6e821770cf99cc586d33833b2ff32faebdbe886bd6322395606cf55153246c"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b60fb58b90c6d63779cb0c0c54eeb38941bae3ecf7a73c764c52c88c2dcb9d72"},
+ {file = "wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b870b5df5b71d8c3359d21be8f0d6c485fa0ebdb6477dda51a1ea54a9b558061"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4011d137b9955791f9084749cba9a367c68d50ab8d11d64c50ba1688c9b457f2"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1473400e5b2733e58b396a04eb7f35f541e1fb976d0c0724d0223dd607e0f74c"},
+ {file = "wrapt-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3cedbfa9c940fdad3e6e941db7138e26ce8aad38ab5fe9dcfadfed9db7a54e62"},
+ {file = "wrapt-1.17.2-cp310-cp310-win32.whl", hash = "sha256:582530701bff1dec6779efa00c516496968edd851fba224fbd86e46cc6b73563"},
+ {file = "wrapt-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:58705da316756681ad3c9c73fd15499aa4d8c69f9fd38dc8a35e06c12468582f"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ff04ef6eec3eee8a5efef2401495967a916feaa353643defcc03fc74fe213b58"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4db983e7bca53819efdbd64590ee96c9213894272c776966ca6306b73e4affda"},
+ {file = "wrapt-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9abc77a4ce4c6f2a3168ff34b1da9b0f311a8f1cfd694ec96b0603dff1c79438"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b929ac182f5ace000d459c59c2c9c33047e20e935f8e39371fa6e3b85d56f4a"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f09b286faeff3c750a879d336fb6d8713206fc97af3adc14def0cdd349df6000"},
+ {file = "wrapt-1.17.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7ed2d9d039bd41e889f6fb9364554052ca21ce823580f6a07c4ec245c1f5d6"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:129a150f5c445165ff941fc02ee27df65940fcb8a22a61828b1853c98763a64b"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1fb5699e4464afe5c7e65fa51d4f99e0b2eadcc176e4aa33600a3df7801d6662"},
+ {file = "wrapt-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9a2bce789a5ea90e51a02dfcc39e31b7f1e662bc3317979aa7e5538e3a034f72"},
+ {file = "wrapt-1.17.2-cp311-cp311-win32.whl", hash = "sha256:4afd5814270fdf6380616b321fd31435a462019d834f83c8611a0ce7484c7317"},
+ {file = "wrapt-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:acc130bc0375999da18e3d19e5a86403667ac0c4042a094fefb7eec8ebac7cf3"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d5e2439eecc762cd85e7bd37161d4714aa03a33c5ba884e26c81559817ca0925"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fc7cb4c1c744f8c05cd5f9438a3caa6ab94ce8344e952d7c45a8ed59dd88392"},
+ {file = "wrapt-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fdbdb757d5390f7c675e558fd3186d590973244fab0c5fe63d373ade3e99d40"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bb1d0dbf99411f3d871deb6faa9aabb9d4e744d67dcaaa05399af89d847a91d"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d18a4865f46b8579d44e4fe1e2bcbc6472ad83d98e22a26c963d46e4c125ef0b"},
+ {file = "wrapt-1.17.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc570b5f14a79734437cb7b0500376b6b791153314986074486e0b0fa8d71d98"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6d9187b01bebc3875bac9b087948a2bccefe464a7d8f627cf6e48b1bbae30f82"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9e8659775f1adf02eb1e6f109751268e493c73716ca5761f8acb695e52a756ae"},
+ {file = "wrapt-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e8b2816ebef96d83657b56306152a93909a83f23994f4b30ad4573b00bd11bb9"},
+ {file = "wrapt-1.17.2-cp312-cp312-win32.whl", hash = "sha256:468090021f391fe0056ad3e807e3d9034e0fd01adcd3bdfba977b6fdf4213ea9"},
+ {file = "wrapt-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:ec89ed91f2fa8e3f52ae53cd3cf640d6feff92ba90d62236a81e4e563ac0e991"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6ed6ffac43aecfe6d86ec5b74b06a5be33d5bb9243d055141e8cabb12aa08125"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35621ae4c00e056adb0009f8e86e28eb4a41a4bfa8f9bfa9fca7d343fe94f998"},
+ {file = "wrapt-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a604bf7a053f8362d27eb9fefd2097f82600b856d5abe996d623babd067b1ab5"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cbabee4f083b6b4cd282f5b817a867cf0b1028c54d445b7ec7cfe6505057cf8"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49703ce2ddc220df165bd2962f8e03b84c89fee2d65e1c24a7defff6f988f4d6"},
+ {file = "wrapt-1.17.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8112e52c5822fc4253f3901b676c55ddf288614dc7011634e2719718eaa187dc"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fee687dce376205d9a494e9c121e27183b2a3df18037f89d69bd7b35bcf59e2"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:18983c537e04d11cf027fbb60a1e8dfd5190e2b60cc27bc0808e653e7b218d1b"},
+ {file = "wrapt-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:703919b1633412ab54bcf920ab388735832fdcb9f9a00ae49387f0fe67dad504"},
+ {file = "wrapt-1.17.2-cp313-cp313-win32.whl", hash = "sha256:abbb9e76177c35d4e8568e58650aa6926040d6a9f6f03435b7a522bf1c487f9a"},
+ {file = "wrapt-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:69606d7bb691b50a4240ce6b22ebb319c1cfb164e5f6569835058196e0f3a845"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:4a721d3c943dae44f8e243b380cb645a709ba5bd35d3ad27bc2ed947e9c68192"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:766d8bbefcb9e00c3ac3b000d9acc51f1b399513f44d77dfe0eb026ad7c9a19b"},
+ {file = "wrapt-1.17.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e496a8ce2c256da1eb98bd15803a79bee00fc351f5dfb9ea82594a3f058309e0"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d615e4fe22f4ad3528448c193b218e077656ca9ccb22ce2cb20db730f8d306"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5aaeff38654462bc4b09023918b7f21790efb807f54c000a39d41d69cf552cb"},
+ {file = "wrapt-1.17.2-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a7d15bbd2bc99e92e39f49a04653062ee6085c0e18b3b7512a4f2fe91f2d681"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e3890b508a23299083e065f435a492b5435eba6e304a7114d2f919d400888cc6"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8c8b293cd65ad716d13d8dd3624e42e5a19cc2a2f1acc74b30c2c13f15cb61a6"},
+ {file = "wrapt-1.17.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c82b8785d98cdd9fed4cac84d765d234ed3251bd6afe34cb7ac523cb93e8b4f"},
+ {file = "wrapt-1.17.2-cp313-cp313t-win32.whl", hash = "sha256:13e6afb7fe71fe7485a4550a8844cc9ffbe263c0f1a1eea569bc7091d4898555"},
+ {file = "wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5c803c401ea1c1c18de70a06a6f79fcc9c5acfc79133e9869e730ad7f8ad8ef9"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f917c1180fdb8623c2b75a99192f4025e412597c50b2ac870f156de8fb101119"},
+ {file = "wrapt-1.17.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ecc840861360ba9d176d413a5489b9a0aff6d6303d7e733e2c4623cfa26904a6"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb87745b2e6dc56361bfde481d5a378dc314b252a98d7dd19a651a3fa58f24a9"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58455b79ec2661c3600e65c0a716955adc2410f7383755d537584b0de41b1d8a"},
+ {file = "wrapt-1.17.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4e42a40a5e164cbfdb7b386c966a588b1047558a990981ace551ed7e12ca9c2"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:91bd7d1773e64019f9288b7a5101f3ae50d3d8e6b1de7edee9c2ccc1d32f0c0a"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:bb90fb8bda722a1b9d48ac1e6c38f923ea757b3baf8ebd0c82e09c5c1a0e7a04"},
+ {file = "wrapt-1.17.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:08e7ce672e35efa54c5024936e559469436f8b8096253404faeb54d2a878416f"},
+ {file = "wrapt-1.17.2-cp38-cp38-win32.whl", hash = "sha256:410a92fefd2e0e10d26210e1dfb4a876ddaf8439ef60d6434f21ef8d87efc5b7"},
+ {file = "wrapt-1.17.2-cp38-cp38-win_amd64.whl", hash = "sha256:95c658736ec15602da0ed73f312d410117723914a5c91a14ee4cdd72f1d790b3"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:99039fa9e6306880572915728d7f6c24a86ec57b0a83f6b2491e1d8ab0235b9a"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2696993ee1eebd20b8e4ee4356483c4cb696066ddc24bd70bcbb80fa56ff9061"},
+ {file = "wrapt-1.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:612dff5db80beef9e649c6d803a8d50c409082f1fedc9dbcdfde2983b2025b82"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62c2caa1585c82b3f7a7ab56afef7b3602021d6da34fbc1cf234ff139fed3cd9"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c958bcfd59bacc2d0249dcfe575e71da54f9dcf4a8bdf89c4cb9a68a1170d73f"},
+ {file = "wrapt-1.17.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc78a84e2dfbc27afe4b2bd7c80c8db9bca75cc5b85df52bfe634596a1da846b"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ba0f0eb61ef00ea10e00eb53a9129501f52385c44853dbd6c4ad3f403603083f"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1e1fe0e6ab7775fd842bc39e86f6dcfc4507ab0ffe206093e76d61cde37225c8"},
+ {file = "wrapt-1.17.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c86563182421896d73858e08e1db93afdd2b947a70064b813d515d66549e15f9"},
+ {file = "wrapt-1.17.2-cp39-cp39-win32.whl", hash = "sha256:f393cda562f79828f38a819f4788641ac7c4085f30f1ce1a68672baa686482bb"},
+ {file = "wrapt-1.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:36ccae62f64235cf8ddb682073a60519426fdd4725524ae38874adf72b5f2aeb"},
+ {file = "wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8"},
+ {file = "wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3"},
+]
+
+[[package]]
+name = "zipp"
+version = "3.21.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"},
+ {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"},
+]
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
+type = ["pytest-mypy"]
+
+[metadata]
+lock-version = "2.1"
+python-versions = ">=3.9,<3.13"
+content-hash = "cb3921d3df77dd5a7c9c7f09fcdf4f4f61b307b04e5bd58c52cfb299ae053da3"
diff --git a/backends/gaudi/server/pyproject.toml b/backends/gaudi/server/pyproject.toml
new file mode 100644
index 000000000..3f2676cbe
--- /dev/null
+++ b/backends/gaudi/server/pyproject.toml
@@ -0,0 +1,45 @@
+[tool.poetry]
+name = "text-generation-server"
+version = "2.0.4"
+description = "Text Generation Inference Python gRPC Server"
+authors = ["Olivier Dehaene "]
+
+[tool.poetry.scripts]
+text-generation-server = 'text_generation_server.cli:app'
+
+[tool.poetry.dependencies]
+python = ">=3.9,<3.13"
+protobuf = "^5.0"
+grpcio = "^1.71.1"
+grpcio-status = "*"
+grpcio-reflection = "*"
+grpc-interceptor = "^0.15.0"
+typer = "^0.15.0"
+loguru = "^0.7.3"
+opentelemetry-api = "^1.32.0"
+opentelemetry-exporter-otlp = "^1.32.0"
+opentelemetry-instrumentation-grpc = "^0.53b0"
+hf-transfer = "^0.1.9"
+sentencepiece = "^0.2.0"
+peft = "^0.15"
+optimum-habana = "1.17"
+transformers = "^4.49"
+numpy = "^1.26"
+accelerate = "^0.33"
+outlines = { version = "^0.0.36", optional = true }
+prometheus-client = "^0.21.1"
+py-cpuinfo = "^9.0.0"
+
+[tool.poetry.group.dev.dependencies]
+grpcio-tools = "*"
+pytest = "^8.3.5"
+
+[tool.pytest.ini_options]
+markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry.requires-plugins]
+poetry-plugin-export = ">=1.8"
diff --git a/backends/gaudi/server/requirements.txt b/backends/gaudi/server/requirements.txt
new file mode 100644
index 000000000..1a5d767f8
--- /dev/null
+++ b/backends/gaudi/server/requirements.txt
@@ -0,0 +1,101 @@
+accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
+attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
+cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
+diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
+diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
+interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
+joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
+jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
+lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
+llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
+numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
+nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
+outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
+peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
+pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
+prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
+psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
+py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
+rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
+scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
+sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
+sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
+shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
+threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
+torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
+triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
+typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
+typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/.devcontainer/Dockerfile.trtllm b/backends/gaudi/server/text_generation_server/__init__.py
similarity index 100%
rename from .devcontainer/Dockerfile.trtllm
rename to backends/gaudi/server/text_generation_server/__init__.py
diff --git a/backends/gaudi/server/text_generation_server/adapters/__init__.py b/backends/gaudi/server/text_generation_server/adapters/__init__.py
new file mode 100644
index 000000000..8697cb9ee
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/__init__.py
@@ -0,0 +1,13 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/__init__.py
+# License: Apache License Version 2.0, January 2004
+
+from text_generation_server.adapters.weights import (
+ AdapterBatchData,
+ AdapterBatchMetadata,
+)
+
+__all__ = [
+ "AdapterBatchData",
+ "AdapterBatchMetadata",
+]
diff --git a/backends/gaudi/server/text_generation_server/adapters/config.py b/backends/gaudi/server/text_generation_server/adapters/config.py
new file mode 100644
index 000000000..b7e270900
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/config.py
@@ -0,0 +1,30 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/config.py
+# License: Apache License Version 2.0, January 2004
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Dict, Set, Tuple
+
+import torch
+
+from text_generation_server.adapters.weights import AdapterWeights
+
+
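+# Holds a module's name together with its adapter tensors, keyed by weight kind (e.g. "lora_A") and paired with the original weight name.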
+@dataclass
+class ModuleMap:
+ module_name: str
+ module_weights: Dict[str, Tuple[torch.Tensor, str]]
+
+
+@dataclass
+class AdapterConfig(ABC):
+ base_model_name_or_path: str
+
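+ # Map raw adapter weights onto the model's target weight names; return the per-module mapping plus the set of adapter weight names that were used.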
+ @abstractmethod
+ def map_weights_for_model(
+ self,
+ adapter_weights: Dict[int, AdapterWeights],
+ weight_names: Tuple[str],
+ ) -> Tuple[ModuleMap, Set[str]]:
+ pass
diff --git a/backends/gaudi/server/text_generation_server/adapters/lora.py b/backends/gaudi/server/text_generation_server/adapters/lora.py
new file mode 100644
index 000000000..a00338e7c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/lora.py
@@ -0,0 +1,471 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/lora.py
+# License: Apache License Version 2.0, January 2004
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from peft import LoraConfig as _LoraConfig
+from torch.distributed import ProcessGroup
+
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+
+from text_generation_server.adapters.weights import (
+ AdapterBatchMetadata,
+ AdapterWeights,
+ BatchAdapterWeights,
+)
+from text_generation_server.utils.sgmv import (
+ BGMV_MAX_RANK,
+ MAX_RANK_CUSTOM,
+ get_tmp_tensors,
+ orient_for_rank,
+ pad_rank,
+ use_cutlass_shrink,
+)
+
+
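+# Return the [start, stop) slice (starting at `offset`) owned by `rank` when `size` items are split evenly across `world_size` ranks.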
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+ block_size = size // world_size
+ start = offset + rank * block_size
+ stop = offset + (rank + 1) * block_size
+ return start, stop
+
+
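+# Keep only this process's shard of tensor `t` along dimension `dim` (rows for dim=0, columns for dim=1).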
+def shard_on_dim(
+ t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+ world_size = process_group.size()
+ rank = process_group.rank()
+
+ size = t.shape[dim]
+ start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+ if dim == 0:
+ tensor = t[start:stop]
+ elif dim == 1:
+ tensor = t[:, start:stop]
+ else:
+ raise NotImplementedError("Let's make that generic when needed")
+
+ return tensor
+
+
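+# Shard each layer's LoRA A matrix along `split_dim` and each B matrix along its second dimension, so every tensor-parallel rank holds only its slice of the adapter.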
+def shard_lora_weights(
+ weights_a: List[torch.Tensor],
+ weights_b: List[torch.Tensor],
+ split_dim: int,
+ process_group: ProcessGroup,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ # [hidden_size, r]
+ weights_a = [
+ shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
+ ]
+
+ # [r, hidden_size]
+ weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
+
+ return weights_a, weights_b
+
+
+@dataclass
+class LoraConfig(AdapterConfig):
+ r: int
+ target_modules: Optional[Union[List[str], str]]
+ fan_in_fan_out: bool
+ lora_alpha: int
+ use_rslora: bool
+
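+ # Collect the lora_A / lora_B tensors for every targeted weight name and record which adapter weight names were consumed.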
+ def map_weights_for_model(
+ self,
+ adapter_weights: Dict[int, AdapterWeights],
+ weight_names: Tuple[str],
+ ) -> Tuple[ModuleMap, Set[str]]:
+ adapter_weight_names = set()
+ module_map = {}
+ for weight_name in weight_names:
+ lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+ lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+ if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
+ continue
+
+ module_map[weight_name] = {
+ "lora_A": (adapter_weights[lora_a_name], lora_a_name),
+ "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+ }
+ adapter_weight_names.add(lora_a_name)
+ adapter_weight_names.add(lora_b_name)
+ return module_map, adapter_weight_names
+
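+ # Load the adapter's PEFT LoraConfig and keep only the fields needed here.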
+ @classmethod
+ def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
+ hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
+ return cls(
+ base_model_name_or_path=hf_config.base_model_name_or_path,
+ r=hf_config.r,
+ target_modules=hf_config.target_modules,
+ fan_in_fan_out=hf_config.fan_in_fan_out,
+ lora_alpha=hf_config.lora_alpha,
+ use_rslora=(
+ hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
+ ),
+ )
+
+
+class LoraWeights(AdapterWeights):
+ """LoRA weights for a single adapter merged across all layers."""
+
+ def __init__(
+ self,
+ weights_a: List[torch.Tensor],
+ weights_b: List[torch.Tensor],
+ adapter_config: LoraConfig,
+ ):
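+ # Ranks of the A and B matrices; default to 1 when the adapter provides no weights.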
+ self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+ self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1
+
+ self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+ self._is_transposed = False
+
+ # [num_layers, hidden_size, r]
+ weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+ self._weights_a = torch.stack(weights_a)
+
+ # [num_layers, r, hidden_size]
+ self._weights_b = torch.stack(weights_b)
+
+ self.adapter_config = adapter_config
+
+ @property
+ def weights_a(self) -> torch.Tensor:
+ if self._is_transposed:
+ self._transpose_weights()
+ return self._weights_a
+
+ @property
+ def weights_b(self) -> torch.Tensor:
+ if self._is_transposed:
+ self._transpose_weights()
+ return self._weights_b
+
+ @property
+ def weights_a_t(self) -> torch.Tensor:
+ if not self._is_transposed:
+ self._transpose_weights()
+ return self._weights_a
+
+ @property
+ def weights_b_t(self) -> torch.Tensor:
+ if not self._is_transposed:
+ self._transpose_weights()
+ return self._weights_b
+
+ def _transpose_weights(self):
+ if self._use_cutlass_shrink:
+            # Only the cutlass shrink kernel needs lora_a transposed; without it, SGMV and BGMV consume the same lora_a orientation
+ self._weights_a = self._weights_a.transpose(1, 2).contiguous()
+ self._weights_b = self._weights_b.transpose(1, 2).contiguous()
+ self._is_transposed = not self._is_transposed
+
+ @classmethod
+ def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+ return [BatchLoraWeights]
+
+ # prepare pre-loaded lora weights for use in the model.
+ #
+ # this method processes and organizes lora weights for a specific layer type across all layers:
+ # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.
+ # - retrieves weights from `module_map` based on the `layer_type`.
+ # - processes `nlayers` number of layers.
+ # - converts weights to the specified `dtype`.
+ # - shards weights across `world_size` number of processes using the `process_group`.
+ # - maps weights to specific layers using `target_to_layer`.
+ # - tracks `unused_weight_names` to identify any unused weights.
+ #
+ # the method handles weight transposition, scaling, and padding to ensure compatibility
+ # with SGMV or BGMV operations.
+ @classmethod
+ def prepare_weights(
+ cls,
+ config: LoraConfig,
+ module_map: Dict[str, Dict],
+ layer_type: str,
+ unused_weight_names: Set[str],
+ nlayers: int,
+ dtype: torch.dtype,
+ world_size: int,
+ process_group: ProcessGroup,
+ target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
+ ) -> Optional[AdapterWeights]:
+ lora_a_list = [None] * nlayers
+ lora_b_list = [None] * nlayers
+
+ for layer_id in range(nlayers):
+ key = (layer_id, layer_type)
+ weight_name, layer = target_to_layer[key]
+ base_weight = layer.base_layer.linear.weight
+ base_device = base_weight.device
+
+ if weight_name not in module_map:
+ # There is no LoRA weight for this layer type in the adapter
+ return None
+
+ lora_a, lora_a_name = module_map[weight_name]["lora_A"]
+ lora_a = lora_a.to(base_device, dtype)
+
+ lora_b, lora_b_name = module_map[weight_name]["lora_B"]
+ lora_b = lora_b.to(base_device, dtype)
+
+ scale = get_scaling_factor(
+ config.lora_alpha,
+ config.r,
+ uses_rslora=config.use_rslora,
+ )
+
+ unused_weight_names.discard(lora_a_name)
+ unused_weight_names.discard(lora_b_name)
+
+ # Merge scaling factor into lora_b due to associativity of matrix multiplication:
+ # (A * B) * C = A * (B * C)
+ lora_a_list[layer_id] = lora_a.transpose(0, 1)
+ lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
+
+ # pad lora ranks to be compatible with sgmv
+ lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
+ lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]
+
+ if lora_a_list:
+ # update rank if it was padded
+ padded_rank = lora_a_list[0].size(1)
+ config.r = padded_rank
+
+ return LoraWeights(
+ *shard_lora_weights(
+ weights_a=lora_a_list,
+ weights_b=lora_b_list,
+ split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
+ process_group=process_group,
+ ),
+ config,
+ )
+
+
+@dataclass
+class RankSegments:
+ rank: int
+
+ lora_a_ptr: torch.Tensor
+ lora_b_ptr: torch.Tensor
+
+ # prefill (sgmv)
+ tmp_shrink: torch.Tensor
+ tmp_expand: torch.Tensor
+ segment_starts: torch.Tensor
+ segment_ends: torch.Tensor
+
+ # decode (bgmv)
+ indices: torch.Tensor
+
+
+@dataclass
+class BatchLoraWeights(BatchAdapterWeights):
+ lora_a: Dict[int, torch.Tensor]
+ lora_b: Dict[int, torch.Tensor]
+ adapter_index_configs: Dict[int, LoraConfig]
+ rank_data: Dict[int, RankSegments]
+ use_sgmv: bool
+
+ def has_adapter(self, adapter_index: int) -> bool:
+ return adapter_index in self.adapter_index_configs
+
+ def can_vectorize(self, pg: ProcessGroup) -> bool:
+ return all(
+ rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+ for rank_data in self.rank_data.values()
+ )
+
+ @classmethod
+ def load(
+        cls,
+ adapter_weights: Dict[int, AdapterWeights],
+ meta: AdapterBatchMetadata,
+ prefill: bool,
+ prefill_head_indices: Optional[torch.Tensor],
+ ) -> Optional["BatchLoraWeights"]:
+ adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+ adapter_weights = {
+ k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+ }
+ if not adapter_weights:
+ return None
+
+ first_weights = next(iter(adapter_weights.values()))
+ device = first_weights.weights_a.device
+ segment_indices = meta.segment_indices
+
+ lora_a = {
+ idx: adapter_weights[idx].weights_a
+ for idx in segment_indices
+ if idx in adapter_weights
+ }
+ lora_b = {
+ idx: adapter_weights[idx].weights_b
+ for idx in segment_indices
+ if idx in adapter_weights
+ }
+
+ max_rank = max(
+ (
+ adapter_weights[idx].lora_a_r
+ for idx in segment_indices
+ if idx in adapter_weights
+ ),
+ default=0,
+ )
+
+ if prefill or max_rank > BGMV_MAX_RANK:
+ use_sgmv = True
+ lora_a_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_a.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+ lora_b_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_b.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+ else:
+ use_sgmv = False
+ lora_a_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_a_t.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+ lora_b_ptr = torch.tensor(
+ [
+ (
+ adapter_weights[idx].weights_b_t.data_ptr()
+ if idx in adapter_weights
+ else 0
+ )
+ for idx in segment_indices
+ ],
+ dtype=torch.int64,
+ device=device,
+ )
+
+ adapter_index_configs = {
+ idx: adapter_weights[idx].adapter_config
+ for idx in segment_indices
+ if idx in adapter_weights
+ }
+
+ adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+ rank_indices = defaultdict(list)
+ for segment_idx, adapter_idx in enumerate(segment_indices):
+ if adapter_idx not in adapter_weights:
+ continue
+ rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+ if prefill_head_indices is not None:
+ j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+ for head_index in prefill_head_indices:
+ # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+ if head_index < meta.adapter_segments[j]:
+ prefill_head_segment_ends[-1] += 1
+ else:
+ prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+ prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+ j += 1
+
+ rank_data = {}
+ for rank, indices in rank_indices.items():
+ tmp_shrink = None
+ tmp_expand = None
+ segment_starts = None
+ segment_ends = None
+ batch_indices = None
+
+ if use_sgmv:
+ lora_a_ptr_indices = lora_a_ptr[indices]
+ tmp_shrink, tmp_expand = get_tmp_tensors(
+ lora_a_ptr_indices.size(0), rank, device
+ )
+ segment_starts = meta.adapter_segments[indices]
+ segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
+ if prefill_head_indices is not None:
+ for i, segment_index in enumerate(indices):
+ segment_starts[i] = prefill_head_segment_starts[segment_index]
+ segment_ends[i] = prefill_head_segment_ends[segment_index]
+ else:
+ rank_indices = set(indices)
+ batch_indices = [
+ adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+ ]
+ batch_indices = [
+ idx if idx in rank_indices else -1 for idx in batch_indices
+ ]
+ batch_indices = torch.tensor(
+ batch_indices, dtype=torch.int64, device=device
+ )
+
+ rank_data[rank] = RankSegments(
+ rank=rank,
+ tmp_shrink=tmp_shrink,
+ tmp_expand=tmp_expand,
+ lora_a_ptr=lora_a_ptr[indices],
+ lora_b_ptr=lora_b_ptr[indices],
+ segment_starts=segment_starts,
+ segment_ends=segment_ends,
+ indices=batch_indices,
+ )
+
+ return BatchLoraWeights(
+ lora_a=lora_a,
+ lora_b=lora_b,
+ adapter_index_configs=adapter_index_configs,
+ rank_data=rank_data,
+ use_sgmv=use_sgmv,
+ )
+
+
+def get_scaling_factor(
+ lora_alpha: int,
+ r: int,
+ uses_rslora: bool = False,
+) -> float:
+ """Computes the scaling factor for the lora weights."""
+ if uses_rslora:
+ return lora_alpha / (r**0.5)
+ return lora_alpha / r
+
+
+def _convert_lora(v: AdapterWeights) -> AdapterWeights:
+ if hasattr(v, "lora_weights"):
+ return v.lora_weights
+ return v
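+
+
+# Worked example of the scaling math above (values chosen only for illustration):
+# with lora_alpha=16 and r=8, get_scaling_factor(16, 8) == 2.0, while the rsLoRA
+# variant get_scaling_factor(16, 8, uses_rslora=True) == 16 / 8**0.5 ≈ 5.657.
+# Because the factor is folded into lora_b in LoraWeights.prepare_weights, the
+# SGMV/BGMV kernels only ever see pre-scaled B matrices.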
diff --git a/backends/gaudi/server/text_generation_server/adapters/weights.py b/backends/gaudi/server/text_generation_server/adapters/weights.py
new file mode 100644
index 000000000..da75dbcdf
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/adapters/weights.py
@@ -0,0 +1,146 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/adapters/weights.py
+# License: Apache License Version 2.0, January 2004
+
+from abc import ABC, abstractclassmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Type
+
+import torch
+
+
+@dataclass
+class AdapterBatchMetadata:
+ # [batch_size]
+ adapter_indices: torch.Tensor
+
+ # [num_adapters]
+ adapter_set: Set[int]
+
+ # [num_segments + 1]
+ adapter_segments: torch.Tensor
+
+ # [num_segments]
+ # maps from segment index to adapter index, i.e.:
+ # segment_indices[s] == adapter_indices[i]
+ segment_indices: List[int]
+
+
+class AdapterWeights(ABC):
+ @abstractclassmethod
+ def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
+ pass
+
+ @property
+ def speculative_tokens(self) -> int:
+ return 0
+
+
+class BatchAdapterWeights(ABC):
+ @abstractclassmethod
+ def has_adapter(self, adapter_index: int) -> bool:
+ pass
+
+ @abstractclassmethod
+ def load(
+ cls,
+ adapter_weights: Dict[int, AdapterWeights],
+ meta: "AdapterBatchMetadata",
+ prefill: bool,
+ prefill_head_indices: torch.Tensor,
+ ) -> Optional["BatchAdapterWeights"]:
+ pass
+
+
+class LayerAdapterWeights:
+ """Adapter weights that apply to a particular layer."""
+
+ def __init__(self):
+ self.adapter_weights: Dict[int, AdapterWeights] = {}
+
+ def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
+ self.adapter_weights[adapter_idx] = weights
+
+ def remove_adapter(self, adapter_idx: int):
+ if adapter_idx not in self.adapter_weights:
+ return
+ del self.adapter_weights[adapter_idx]
+
+ def is_empty(self) -> bool:
+ return len(self.adapter_weights) == 0
+
+ def get_data(
+ self,
+ meta: AdapterBatchMetadata,
+ prefill: bool,
+ prefill_head_indices: Optional[torch.Tensor],
+ ) -> Dict[str, BatchAdapterWeights]:
+ # bucket adapters by batch class
+ adapter_batch_types: Dict[
+ Type[BatchAdapterWeights], Dict[int, AdapterWeights]
+ ] = defaultdict(dict)
+ for adapter_index, adapter_weights in self.adapter_weights.items():
+ for batch_type in adapter_weights.get_batch_types():
+ adapter_batch_types[batch_type][adapter_index] = adapter_weights
+
+ batch_data = {}
+ for batch_type, adapter_weights in adapter_batch_types.items():
+ batched_weights = batch_type.load(
+ adapter_weights, meta, prefill, prefill_head_indices
+ )
+ if batched_weights is not None:
+ batch_data = batched_weights
+ return batch_data
+
+
+@dataclass
+class AdapterBatchData:
+ meta: AdapterBatchMetadata
+
+ # layer type -> adapter type -> batch weight data
+ data: Dict[str, Dict[str, BatchAdapterWeights]]
+
+ prefill: bool
+
+ @staticmethod
+ def from_meta(
+ meta: AdapterBatchMetadata,
+ weights: Dict[str, LayerAdapterWeights],
+ prefill: bool,
+ prefill_head_indices: Optional[torch.Tensor],
+ ) -> "AdapterBatchData":
+ data = {}
+ for k, v in weights.items():
+ if v.is_empty():
+ continue
+ data[k] = v.get_data(
+ meta, prefill, prefill_head_indices if k == "lm_head" else None
+ )
+ return AdapterBatchData(meta=meta, data=data, prefill=prefill)
+
+ def ranks(self) -> Set[int]:
+ # TODO(travis): refactor to be less coupled to lora implementation
+ ranks = set()
+ for lora_data in self.data.values():
+ if lora_data is None:
+ continue
+
+ for rank_data in lora_data.rank_data.values():
+ ranks.add(rank_data.rank)
+
+ return ranks
+
+ def layer_names(self) -> Set[str]:
+ return set(self.data.keys())
+
+ def adapter_keys(self) -> Set[str]:
+ adapter_keys = set()
+ for layer_data in self.data.values():
+ adapter_keys.update(layer_data.keys())
+ return adapter_keys
+
+ @property
+ def max_rank(self) -> int:
+ ranks = self.ranks()
+ return max(ranks) if len(ranks) > 0 else 0
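+
+
+# Constructed example of the metadata layout (token/adapter ids are made up):
+# for a batch of 5 tokens routed to adapters [0, 0, 1, 1, 0], one would expect
+#   adapter_indices  = tensor([0, 0, 1, 1, 0])
+#   adapter_set      = {0, 1}
+#   adapter_segments = tensor([0, 2, 4, 5])  # segment boundaries
+#   segment_indices  = [0, 1, 0]             # adapter owning each segment
+# i.e. segment s covers the tokens adapter_segments[s]:adapter_segments[s + 1].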
diff --git a/backends/gaudi/server/text_generation_server/cache.py b/backends/gaudi/server/text_generation_server/cache.py
new file mode 100644
index 000000000..4504733e5
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/cache.py
@@ -0,0 +1,34 @@
+import torch
+
+from typing import Dict, Optional, TypeVar
+
+from text_generation_server.models.types import Batch
+
+B = TypeVar("B", bound=Batch)
+
+
+class Cache:
+ def __init__(self):
+ self.cache: Dict[int, B] = {}
+
+ def pop(self, batch_id: int) -> Optional[B]:
+ return self.cache.pop(batch_id, None)
+
+ def set(self, entry: B):
+ if entry is not None:
+ self.cache[entry.batch_id] = entry
+
+ def delete(self, batch_id: int):
+ batch = self.pop(batch_id)
+ if batch is not None:
+ del batch
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def clear(self):
+ keys = list(self.cache.keys())
+ for k in keys:
+ self.delete(k)
+
+ def __len__(self):
+ return len(self.cache.keys())
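+
+
+# Minimal usage sketch (`batch` and `stale_batch_id` are placeholders for a
+# concrete Batch instance and a previously cached id):
+#   cache = Cache()
+#   cache.set(batch)                 # keyed by batch.batch_id
+#   same_batch = cache.pop(batch.batch_id)
+#   cache.delete(stale_batch_id)     # also empties the CUDA cache when available
+#   cache.clear()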
diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
new file mode 100644
index 000000000..53837ef71
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -0,0 +1,426 @@
+import os
+import psutil
+import signal
+import sys
+import typer
+
+from pathlib import Path
+from loguru import logger
+from typing import Optional
+from enum import Enum
+from huggingface_hub import hf_hub_download
+from text_generation_server.utils.adapter import parse_lora_adapters
+
+
+app = typer.Typer()
+
+
+class Quantization(str, Enum):
+ gptq = "gptq"
+ awq = "awq"
+ fp8 = "fp8"
+
+
+class Dtype(str, Enum):
+ float16 = "float16"
+ bloat16 = "bfloat16"
+
+
+@app.command()
+def serve(
+ model_id: str,
+ revision: Optional[str] = None,
+ sharded: bool = False,
+ quantize: Optional[Quantization] = None,
+ speculate: Optional[int] = None,
+ dtype: Optional[Dtype] = None,
+ trust_remote_code: bool = False,
+ uds_path: Path = "/tmp/text-generation-server",
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ otlp_endpoint: Optional[str] = None,
+ otlp_service_name: str = "text-generation-inference.server",
+ max_input_tokens: Optional[int] = None,
+):
+ if sharded:
+ # assert (
+ # os.getenv("RANK", None) is not None
+ # ), "RANK must be set when sharded is True"
+ assert (
+ os.getenv("WORLD_SIZE", None) is not None
+ ), "WORLD_SIZE must be set when sharded is True"
+ assert (
+ os.getenv("MASTER_ADDR", None) is not None
+ ), "MASTER_ADDR must be set when sharded is True"
+ assert (
+ os.getenv("MASTER_PORT", None) is not None
+ ), "MASTER_PORT must be set when sharded is True"
+
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ # Import here after the logger is added to log potential import exceptions
+ from text_generation_server import server
+ from text_generation_server.tracing import setup_tracing
+
+ # Setup OpenTelemetry distributed tracing
+ if otlp_endpoint is not None:
+ setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
+
+ lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
+
+ # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
+ # and warn the user
+ if lora_adapters:
+ logger.warning("LoRA adapters enabled (experimental feature).")
+
+ if "CUDA_GRAPHS" in os.environ:
+ logger.warning(
+ "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
+ )
+ global CUDA_GRAPHS
+ CUDA_GRAPHS = None
+
+ # Downgrade enum into str for easier management later on
+ quantize = None if quantize is None else quantize.value
+ dtype = "bfloat16" if dtype is None else dtype.value
+ logger.info(f"quantize={quantize}")
+ if dtype is not None and quantize not in {
+ None,
+ "bitsandbytes",
+ "bitsandbytes-nf4",
+ "bitsandbytes-fp4",
+ "gptq",
+ "awq",
+ "fp8",
+ }:
+        raise RuntimeError(
+            "Only one of `dtype` and `quantize` can be set, as they both determine how the final model is loaded."
+        )
+
+ logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
+
+ if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
+ tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
+ num_shard = int(os.getenv("WORLD_SIZE", "1"))
+ logger.info("CLI SHARDED = {}".format(num_shard))
+ import subprocess
+
+ cmd = (
+ f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
+ )
+ cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
+ cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
+ cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
+ if speculate is not None:
+            cmd += f" --speculate {speculate}"
+ logger.info("CLI server start deepspeed ={} ".format(cmd))
+ sys.stdout.flush()
+ sys.stderr.flush()
+ with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
+ do_terminate = False
+ current_handler = signal.getsignal(signal.SIGTERM)
+
+ def terminate_handler(sig, frame):
+ nonlocal do_terminate
+ do_terminate = True
+ if callable(current_handler):
+ current_handler(sig, frame)
+
+ signal.signal(signal.SIGTERM, terminate_handler)
+
+ finished = False
+ while not finished:
+ try:
+ if do_terminate:
+ parent = psutil.Process(proc.pid)
+ all_procs = parent.children(recursive=True) + [parent]
+ for p in all_procs:
+ try:
+ p.terminate()
+ except psutil.NoSuchProcess:
+ pass
+ _, alive = psutil.wait_procs(all_procs, timeout=30)
+ for p in alive:
+ p.kill()
+
+ do_terminate = False
+
+ proc.wait(timeout=3)
+ except subprocess.TimeoutExpired:
+ pass
+ else:
+ finished = True
+
+ sys.stdout.flush()
+ sys.stderr.flush()
+ if proc.returncode != 0:
+ logger.error(f"{cmd} exited with status = {proc.returncode}")
+ return proc.returncode
+ else:
+ server.serve(
+ model_id,
+ lora_adapters,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ trust_remote_code,
+ uds_path,
+ max_input_tokens,
+ )
+
+
+@app.command()
+def download_weights(
+ model_id: str,
+ revision: Optional[str] = None,
+ extension: str = ".safetensors",
+ auto_convert: bool = True,
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ trust_remote_code: bool = False,
+ merge_lora: bool = False,
+):
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ # Import here after the logger is added to log potential import exceptions
+ from text_generation_server import utils
+
+    # Check whether the files were already downloaded
+ try:
+ utils.weight_files(model_id, revision, extension)
+ logger.info("Files are already present on the host. " "Skipping download.")
+ return
+ # Local files not found
+ except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+ pass
+
+ is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
+ "WEIGHTS_CACHE_OVERRIDE", None
+ ) is not None
+
+ if not is_local_model:
+ # TODO: maybe reverse the default value of merge_lora?
+ # currently by default we don't merge the weights with the base model
+ if merge_lora:
+ try:
+ hf_hub_download(
+ model_id, revision=revision, filename="adapter_config.json"
+ )
+ utils.download_and_unload_peft(
+ model_id, revision, trust_remote_code=trust_remote_code
+ )
+ is_local_model = True
+ utils.weight_files(model_id, revision, extension)
+ return
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+ else:
+ try:
+ utils.peft.download_peft(
+ model_id, revision, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ pass
+
+ try:
+ import json
+
+ config = hf_hub_download(
+ model_id, revision=revision, filename="config.json"
+ )
+ with open(config, "r") as f:
+ config = json.load(f)
+
+ base_model_id = config.get("base_model_name_or_path", None)
+ if base_model_id and base_model_id != model_id:
+ try:
+ logger.info(f"Downloading parent model {base_model_id}")
+ download_weights(
+ model_id=base_model_id,
+ revision="main",
+ extension=extension,
+ auto_convert=auto_convert,
+ logger_level=logger_level,
+ json_output=json_output,
+ trust_remote_code=trust_remote_code,
+ )
+ except Exception:
+ pass
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+
+ # Try to download weights from the hub
+ try:
+ filenames = utils.weight_hub_files(model_id, revision, extension)
+ utils.download_weights(filenames, model_id, revision)
+ # Successfully downloaded weights
+ return
+
+ # No weights found on the hub with this extension
+ except utils.EntryNotFoundError as e:
+ # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
+ if not extension == ".safetensors" or not auto_convert:
+ raise e
+
+ elif (Path(model_id) / "adapter_config.json").exists():
+ # Try to load as a local PEFT model
+ try:
+ utils.download_and_unload_peft(
+ model_id, revision, trust_remote_code=trust_remote_code
+ )
+ utils.weight_files(model_id, revision, extension)
+ return
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+ elif (Path(model_id) / "config.json").exists():
+ # Try to load as a local Medusa model
+ try:
+ import json
+
+ config = Path(model_id) / "config.json"
+ with open(config, "r") as f:
+ config = json.load(f)
+
+ base_model_id = config.get("base_model_name_or_path", None)
+ if base_model_id:
+ try:
+ logger.info(f"Downloading parent model {base_model_id}")
+ download_weights(
+ model_id=base_model_id,
+ revision="main",
+ extension=extension,
+ auto_convert=auto_convert,
+ logger_level=logger_level,
+ json_output=json_output,
+ trust_remote_code=trust_remote_code,
+ )
+ except Exception:
+ pass
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ pass
+
+ # Try to see if there are local pytorch weights
+ try:
+ # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
+ try:
+ local_pt_files = utils.weight_files(model_id, revision, ".bin")
+ except Exception:
+ local_pt_files = utils.weight_files(model_id, revision, ".pt")
+
+ # No local pytorch weights
+ except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+ if extension == ".safetensors":
+ logger.warning(
+ f"No safetensors weights found for model {model_id} at revision {revision}. "
+ f"Downloading PyTorch weights."
+ )
+
+ # Try to see if there are pytorch weights on the hub
+ pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+ # Download pytorch weights
+ local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
+
+ if auto_convert:
+ if not trust_remote_code:
+                logger.warning(
+                    "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
+                    "Pickle files are unsafe and can essentially contain remote code execution! "
+                    "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
+                )
+
+ logger.warning(
+ f"No safetensors weights found for model {model_id} at revision {revision}. "
+ f"Converting PyTorch weights to safetensors."
+ )
+
+ # Safetensors final filenames
+ local_st_files = [
+            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
+ for p in local_pt_files
+ ]
+ try:
+ import transformers
+ import json
+
+ if is_local_model:
+ config_filename = os.path.join(model_id, "config.json")
+ else:
+ config_filename = hf_hub_download(
+ model_id, revision=revision, filename="config.json"
+ )
+ with open(config_filename, "r") as f:
+ config = json.load(f)
+ architecture = config["architectures"][0]
+
+ class_ = getattr(transformers, architecture)
+
+            # The name of this attribute depends on the transformers version.
+ discard_names = getattr(class_, "_tied_weights_keys", [])
+
+ except Exception:
+ discard_names = []
+ # Convert pytorch weights to safetensors
+ utils.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+@app.command()
+def quantize(
+ model_id: str,
+ output_dir: str,
+ revision: Optional[str] = None,
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ trust_remote_code: bool = False,
+ upload_to_model_id: Optional[str] = None,
+ percdamp: float = 0.01,
+ act_order: bool = False,
+ groupsize: int = 128,
+):
+ if revision is None:
+ revision = "main"
+ download_weights(
+ model_id=model_id,
+ revision=revision,
+ logger_level=logger_level,
+ json_output=json_output,
+ )
+ from text_generation_server.layers.gptq.quantize import quantize
+
+ quantize(
+ model_id=model_id,
+ bits=4,
+ groupsize=groupsize,
+ output_dir=output_dir,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ upload_to_model_id=upload_to_model_id,
+ percdamp=percdamp,
+ act_order=act_order,
+ sym=True,
+ )
+
+
+if __name__ == "__main__":
+ app()
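+
+
+# Illustrative invocations, assuming the package keeps exposing this Typer app as
+# the `text-generation-server` console script (the model id is an example only):
+#   text-generation-server download-weights bigscience/bloom-560m
+#   text-generation-server serve bigscience/bloom-560m --dtype bfloat16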
diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py
new file mode 100644
index 000000000..b03b7e266
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/habana_quantization_env.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import os
+import habana_frameworks.torch as htorch
+
+quant_config = os.getenv("QUANT_CONFIG", "")
+is_quantization_enabled = quant_config != ""
+
+if is_quantization_enabled:
+ os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
+ os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
+ os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
+ os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
+ os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+ os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
+
+
+def patch_scoped_linear_all_reduce(model):
+ from deepspeed.module_inject.layers import LinearAllreduce
+ from optimum.habana.transformers.models.modeling_all_models import (
+ ScopedLinearAllReduce,
+ )
+
+ for name, module in model.named_children():
+ if type(module) is LinearAllreduce:
+ SL = ScopedLinearAllReduce(mod=module)
+ setattr(model, name, SL)
+ patch_scoped_linear_all_reduce(module)
+
+
+def setup_quantization(model):
+ if is_quantization_enabled:
+ htorch.core.quantization._mark_params_as_const(model)
+ htorch.core.quantization._check_params_as_const(model)
+ htorch.core.hpu_initialize(model)
+ return model
+
+
+def prepare_model_for_quantization(model):
+ if is_quantization_enabled:
+ if model.config.model_type in [
+ "llama",
+ "falcon",
+ "qwen2",
+ "starcoder2",
+ "gemma",
+ ]:
+ patch_scoped_linear_all_reduce(model)
+ from neural_compressor.torch.quantization import FP8Config, convert
+
+ config = FP8Config.from_json_file(quant_config)
+ model = convert(model, config)
+ return model
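+
+
+# Usage sketch: FP8 quantization is driven entirely by the QUANT_CONFIG environment
+# variable, e.g. (the path below is an example only)
+#   QUANT_CONFIG=/path/to/maxabs_quant.json
+# When it is unset, setup_quantization and prepare_model_for_quantization simply
+# return the model unchanged.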
diff --git a/backends/gaudi/server/text_generation_server/interceptor.py b/backends/gaudi/server/text_generation_server/interceptor.py
new file mode 100644
index 000000000..47f33cd0b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/interceptor.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import torch
+import grpc
+
+from google.rpc import status_pb2, code_pb2
+from grpc_status import rpc_status
+from grpc_interceptor.server import AsyncServerInterceptor
+from loguru import logger
+from typing import Callable, Any
+import traceback
+import os
+
+
+class ExceptionInterceptor(AsyncServerInterceptor):
+ async def intercept(
+ self,
+ method: Callable,
+ request_or_iterator: Any,
+ context: grpc.ServicerContext,
+ method_name: str,
+ ) -> Any:
+ try:
+ response = method(request_or_iterator, context)
+ return await response
+ except Exception as err:
+ trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
+ method_name = method_name.split("/")[-1]
+ logger.exception(f"Method {method_name} encountered an error.")
+
+ # Runtime Error cannot be recovered from
+ if isinstance(err, RuntimeError):
+ exit(1)
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ from .utils.debug import dbg_trace
+
+ dbg_trace("EXCEPTION", traceback.format_exc())
+ await context.abort_with_status(
+ rpc_status.to_status(
+ status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
+ )
+ )
diff --git a/backends/gaudi/server/text_generation_server/layers/__init__.py b/backends/gaudi/server/text_generation_server/layers/__init__.py
new file mode 100644
index 000000000..0000ca915
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/__init__.py
@@ -0,0 +1,34 @@
+from text_generation_server.layers.tensor_parallel import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+)
+from text_generation_server.layers.linear import (
+ get_linear,
+ FastLinear,
+)
+from text_generation_server.layers.speculative import SpeculativeHead
+
+# Just to add the `load` methods.
+from text_generation_server.layers.layernorm import load_layer_norm
+from text_generation_server.layers.conv import load_conv2d
+
+from text_generation_server.layers.lora import (
+ LoraLinear,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+
+__all__ = [
+ "get_linear",
+ "FastLinear",
+ "TensorParallelColumnLinear",
+ "TensorParallelRowLinear",
+ "TensorParallelEmbedding",
+ "SpeculativeHead",
+ "LoraLinear",
+ "TensorParallelMultiAdapterLinear",
+ "TensorParallelAdapterRowLinear",
+ "load_layer_norm",
+ "load_conv2d",
+]
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
new file mode 100644
index 000000000..9ba9f6e08
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -0,0 +1,28 @@
+from .common import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+ trim_attn_metadata,
+ trim_seqlen_metadata,
+)
+
+from .hpu import (
+ SUPPORTS_WINDOWING,
+ attention,
+ paged_attention,
+)
+
+
+# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
+from .kv_cache import KVCache, get_kv_scales
+
+__all__ = [
+ "attention",
+ "get_kv_scales",
+ "paged_attention",
+ "SUPPORTS_WINDOWING",
+ "KVCache",
+ "Seqlen",
+ "HPUPagedAttentionMetadata",
+ "trim_seqlen_metadata",
+ "trim_attn_metadata",
+]
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/common.py b/backends/gaudi/server/text_generation_server/layers/attention/common.py
new file mode 100644
index 000000000..8ec9fb461
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+import torch
+from typing import Optional, List, Dict
+import collections
+
+_TYPE_CACHE = {}
+
+
+@dataclass
+class HPUPagedAttentionMetadata:
+ """Metadata for PagedAttention."""
+
+ block_list: Optional[torch.Tensor]
+ block_mapping: Optional[torch.Tensor]
+ block_usage: Optional[torch.Tensor]
+ block_scales: Optional[torch.Tensor]
+ block_groups: Optional[torch.Tensor]
+ attn_bias: Optional[torch.Tensor]
+
+
+def subtuple(
+ obj: object,
+ typename: str,
+ to_copy: List[str],
+ to_override: Optional[Dict[str, object]] = None,
+):
+ if obj is None:
+ return None
+ if to_override is None:
+ to_override = {}
+ fields = set(to_copy) | set(to_override.keys())
+ if isinstance(obj, dict):
+ values = {key: obj[key] for key in fields if key in obj}
+ else:
+ values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+ if typename not in _TYPE_CACHE:
+ _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
+ return _TYPE_CACHE[typename](**values)
+
+
+def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
+ # NOTE(kzawora): To anyone working on this in the future:
+ # Trimming metadata is required when using HPUGraphs.
+ # Attention metadata is going to be hashed by PT bridge, and
+ # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+ # Before you put more keys in here, make sure you know their
+ # value type and make sure you know how it's going to be hashed.
+ # You can find that information in input_hash function
+ # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+ # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+ # If you use primitive types here - they will get hashed based
+ # on their value. You *will* get lots of excessive graph captures
+ # (and an OOM eventually) if you decide to put something like
+ # seq_len int here.
+ # If you absolutely need a scalar, put it in a tensor. Tensors
+ # get hashed using their metadata, not their values:
+ # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+ # input_hash(123) != input_hash(321)
+ # input_hash("abc") != input_hash("cba")
+ attention_metadata = subtuple(
+ metadata,
+ "TrimmedAttentionMetadata",
+ [
+ "block_list",
+ "block_mapping",
+ "block_usage",
+ "block_scales",
+ "block_groups",
+ "attn_bias",
+ ],
+ )
+ return attention_metadata
+
+
+@dataclass
+class Seqlen:
+ input_lengths: torch.Tensor
+ cache_lengths: torch.Tensor
+ cu_seqlen_q: Optional[torch.Tensor]
+ cu_seqlen_k: Optional[torch.Tensor]
+
+ def __init__(
+ self,
+ input_lengths,
+ cache_lengths,
+ cu_seqlen_q=None,
+ ):
+ self.input_lengths = input_lengths
+ self.cache_lengths = cache_lengths
+ device = self.input_lengths.device
+ shape = self.input_lengths.shape
+ if cu_seqlen_q is None:
+ cu_seqlen_q = torch.arange(
+ shape[0] + 1,
+ device=device,
+ dtype=torch.int32,
+ )
+ cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+
+ # cuda graphs don't like this and this is necessary to clamp within mistral
+ # Although FA2 might not want the clamping
+ # cu_seqlen_k[0] = 0
+ total = self.input_lengths + self.cache_lengths
+ torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+
+ self.cu_seqlen_q = cu_seqlen_q
+ self.cu_seqlen_k = cu_seqlen_k
+
+ def clamp(self, max):
+ # Flash decoding doesn't need to clamp
+ return self
+
+
+def trim_seqlen_metadata(metadata: Seqlen) -> object:
+ # NOTE(kzawora): To anyone working on this in the future:
+ # Trimming metadata is required when using HPUGraphs.
+ # Attention metadata is going to be hashed by PT bridge, and
+ # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+ # Before you put more keys in here, make sure you know their
+ # value type and make sure you know how it's going to be hashed.
+ # You can find that information in input_hash function
+ # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+ # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+ # If you use primitive types here - they will get hashed based
+ # on their value. You *will* get lots of excessive graph captures
+ # (and an OOM eventually) if you decide to put something like
+ # seq_len int here.
+ # If you absolutely need a scalar, put it in a tensor. Tensors
+ # get hashed using their metadata, not their values:
+ # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+ # input_hash(123) != input_hash(321)
+ # input_hash("abc") != input_hash("cba")
+ attention_metadata = subtuple(
+ metadata,
+ "TrimmedSeqlen",
+ [
+ "input_lengths",
+ "cache_lengths",
+ "cu_seqlen_q",
+ "cu_seqlen_k",
+ ],
+ )
+ return attention_metadata
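+
+
+# Illustrative example (t1..t5 stand for arbitrary tensors): subtuple() projects a
+# dataclass or dict onto a cached namedtuple type, keeping only hash-friendly
+# fields and copying them by reference:
+#   meta = HPUPagedAttentionMetadata(block_list=t1, block_mapping=t2,
+#                                    block_usage=t3, block_scales=t4,
+#                                    block_groups=t5, attn_bias=None)
+#   trimmed = trim_attn_metadata(meta)
+#   assert trimmed.block_list is t1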
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
new file mode 100644
index 000000000..f34e93abc
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -0,0 +1,95 @@
+import torch
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from typing import Optional
+from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
+from vllm_hpu_extension import ops
+from vllm_hpu_extension.utils import Matmul
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+import os
+
+SUPPORTS_WINDOWING = False
+
+
+def fetch_from_cache(cache, blocks):
+ if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
+ return cache[: blocks.size(0)]
+ else:
+ return cache.index_select(0, blocks)
+
+
+def attention(
+ *,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ kv_cache: KVCache,
+ kv_scales: KVScales,
+ seqlen: Seqlen,
+ softmax_scale: float,
+ window_size_left: int = -1,
+ causal: bool = True,
+ softcap: Optional[float] = None,
+):
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ bs = seqlen.input_lengths.shape[0]
+ _, head_num, head_size = query.shape
+ _, kv_head_num, head_size = key.shape
+ query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
+ key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
+ value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=causal,
+ scale=softmax_scale,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=seqlen.input_lengths,
+ padding_side="left",
+ )
+ attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+ return attn_output
+
+
+def paged_attention(
+ query: torch.Tensor,
+ kv_cache: KVCache,
+ kv_head_mapping: torch.Tensor,
+ softmax_scale: float,
+ seqlen: Seqlen,
+ *,
+ kv_scales: KVScales,
+ softcap: Optional[float] = None,
+ hpu_attention_meta: HPUPagedAttentionMetadata,
+):
+ batch_size, head_num, head_size = query.shape
+ output = ops.flat_pa(
+ query=query.view(batch_size, 1, head_num * head_size),
+ key_cache=kv_cache.key,
+ value_cache=kv_cache.value,
+ block_list=hpu_attention_meta.block_list,
+ block_mapping=hpu_attention_meta.block_mapping,
+ block_bias=hpu_attention_meta.attn_bias,
+ block_scales=hpu_attention_meta.block_scales,
+ block_groups=hpu_attention_meta.block_groups,
+ scale=softmax_scale,
+ matmul_qk_op=Matmul(),
+ matmul_av_op=Matmul(),
+ batch2block_matmul_op=Matmul(),
+ block2batch_matmul_op=Matmul(),
+ keys_fetch_func=fetch_from_cache,
+ values_fetch_func=fetch_from_cache,
+ )
+ # Reshape the output tensor.
+ return output.view(batch_size, head_num, head_size)
+
+
+__all__ = [
+ "SUPPORTS_WINDOWING",
+ "attention",
+ "paged_attention",
+]
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
new file mode 100644
index 000000000..d238cdb97
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
@@ -0,0 +1,139 @@
+from typing import Tuple
+from dataclasses import dataclass, field
+
+import torch
+
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.weights import Weights
+from vllm_hpu_extension import cache_ops
+
+
+@dataclass
+class KVScales:
+ """
+ Key-value scales for FP8 KV cache.
+
+ This data class stores key and value scales both as a GPU tensor and
+ as a GPU float. This inconvenience is necessary because some functions
+ (e.g. scaling kernels) take scales as a GPU tensor, whereas others
+ (e.g. flashinfer) take scales as a CPU scalar.
+ """
+
+ key_scale: torch.Tensor
+ value_scale: torch.Tensor
+ key_scale_cpu: float = field(init=False)
+ value_scale_cpu: float = field(init=False)
+
+ def __post_init__(self):
+ if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
+ raise ValueError("Key and value scales must be scalar tensors.")
+
+ self.key_scale_cpu = self.key_scale.item()
+ self.value_scale_cpu = self.value_scale.item()
+
+
+class KVCache:
+ """
+ Key-value cache for attention layers.
+ """
+
+ kv_cache: Tuple[torch.Tensor, torch.Tensor]
+
+ def __init__(
+ self,
+ *,
+ num_blocks: int,
+ num_heads: int,
+ head_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ):
+ """Construct the key-value cache for a layer."""
+ ## TODO FP8 kv cache support
+
+ self.kv_cache = (
+ torch.zeros(
+ (num_blocks, BLOCK_SIZE, num_heads, head_size),
+ dtype=dtype,
+ device=device,
+ ),
+ torch.zeros(
+ (num_blocks, BLOCK_SIZE, num_heads, head_size),
+ dtype=dtype,
+ device=device,
+ ),
+ )
+
+ @property
+ def dtype(self):
+ """Get the data type of the cache."""
+ return self.kv_cache[0].dtype
+
+ @property
+ def key(self):
+ """Get the key cache."""
+
+ return self.kv_cache[0]
+
+ @property
+ def value(self):
+ """Get the value cache."""
+
+ return self.kv_cache[1]
+
+ def store(
+ self,
+ *,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ slots: torch.Tensor,
+ kv_scales: KVScales,
+ ):
+ """Store the key and value at the given slots."""
+ ## TODO FP8 kv cache support
+
+ key_cache = self.kv_cache[0]
+ value_cache = self.kv_cache[1]
+
+ paged_reshape_and_cache(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ slots,
+ kv_scales.key_scale_cpu,
+ kv_scales.value_scale_cpu,
+ )
+
+
+def paged_reshape_and_cache(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ slots: torch.Tensor,
+ k_scale: float = 1.0,
+ v_scale: float = 1.0,
+):
+ block_idx = slots // BLOCK_SIZE
+ block_offset = slots % BLOCK_SIZE
+ cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
+ cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
+
+
+def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
+ """Load KV cache scales."""
+
+ key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
+ value_scale = key_scale
+ if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
+ f"{prefix}.v_scale"
+ ):
+ key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
+ value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
+ elif weights.has_tensor(f"{prefix}.kv_scale"):
+ # Fall back to older more coarse-grained scale when available.
+ key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
+ value_scale = key_scale
+
+ return KVScales(key_scale=key_scale, value_scale=value_scale)
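+
+
+# Worked example of the slot -> (block, offset) mapping in paged_reshape_and_cache:
+# assuming BLOCK_SIZE == 128 (the actual value comes from models.globals), slot 300
+# maps to block 300 // 128 == 2 at offset 300 % 128 == 44, i.e. row 44 of the third
+# block in both the key and value caches.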
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py b/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py
new file mode 100644
index 000000000..b19eafbbe
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py
@@ -0,0 +1,97 @@
+import torch
+from typing import List
+
+
+AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
+REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+
+
+def pack(imatrix: torch.Tensor, direction: str = "column"):
+ """
+ Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
+ Args:
+ imatrix (torch.Tensor): matrix of integers
+ direction (str): direction of packing, either "column" or "row"
+ Returns:
+ qmatrix (torch.Tensor): packed matrix of integers
+ """
+ shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to guard against overflow
+
+ if direction == "column":
+ imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
+ qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
+
+ elif direction == "row":
+ imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
+ qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
+
+ qmatrix = qmatrix.to(torch.int32)
+
+ return qmatrix
+
+
+def unpack(qmatrix: torch.Tensor, direction: str = "column"):
+ """
+ Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
+ Args:
+ qmatrix (torch.Tensor): matrix of packed integers
+ direction (str): direction of unpacking, either "column" or "row"
+ Returns:
+ imatrix (torch.Tensor): matrix of integers
+ """
+ shifts = torch.arange(0, 32, 4, device=qmatrix.device)
+
+ if direction == "column":
+ imatrix = torch.bitwise_right_shift(
+ qmatrix[:, :, None], shifts[None, None, :]
+ ).view(qmatrix.shape[0], -1)
+
+ elif direction == "row":
+ imatrix = torch.bitwise_right_shift(
+ qmatrix[:, None, :], shifts[None, :, None]
+ ).view(-1, qmatrix.shape[-1])
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to guard against overflow
+
+ return imatrix
+
+
+def apply_order(
+ imatrix: torch.Tensor,
+ direction: str = "column",
+ order: List[int] = AWQ_PACK_ORDER,
+):
+ """
+ Applies the order to a 4-bit integer matrix.
+ Args:
+ imatrix (torch.Tensor): matrix of integers
+ direction (str): direction of applying order, either "column" or "row"
+ order (List[int]): order to apply, default is AWQ_PACK_ORDER
+ Returns:
+ imatrix (torch.Tensor): matrix of integers
+ """
+ if direction == "column":
+ imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
+ elif direction == "row":
+ imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
+
+ return imatrix
+
+
+def fast_awq_to_gptq(qweight, qzeros):
+ # awq uses column packing for both weights and zeros
+ izeros = unpack(qzeros, direction="column")
+ iweights = unpack(qweight, direction="column")
+
+ # Reverse the order of the iweight and izeros tensors
+ izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+ iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+ # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
+ izeros = izeros - 1
+ # exllama uses row packing for weights and column packing for zeros
+ qzeros = pack(izeros, direction="column")
+ qweight = pack(iweights, direction="row")
+
+ return qweight, qzeros
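+
+
+# Round-trip sketch (a hypothetical sanity check, not used by the conversion path):
+# for i = torch.randint(0, 16, (8, 64), dtype=torch.int32),
+# unpack(pack(i, "column"), "column") recovers i as int8. Eight 4-bit values are
+# packed per int32, so an (8, 64) matrix packs to (8, 8) column-wise and to (1, 64)
+# row-wise.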
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py
new file mode 100644
index 000000000..856d7c281
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py
@@ -0,0 +1,3 @@
+from .hpu import WQLinear
+
+__all__ = ["WQLinear"]
diff --git a/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py
new file mode 100644
index 000000000..3af0131b3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py
@@ -0,0 +1,134 @@
+from typing import Optional
+import torch
+import torch.nn as nn
+
+try:
+ import habana_frameworks.torch.hpu # noqa: F401
+
+ convert_from_uint4 = torch.ops.hpu.convert_from_uint4
+except Exception as e:
+ hpu_import_exception = e
+
+ def error_raiser_hpu(*args, **kwargs):
+ raise ValueError(
+ f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
+ )
+
+ convert_from_uint4 = error_raiser_hpu
+
+AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+
+
+def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
+ shifts = torch.arange(0, 32, bits, device=qzeros.device)
+
+ # unpacking columnwise
+ iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
+ torch.int8 # smallest dtype available
+ )
+ iweights = iweights.view(iweights.shape[0], -1)
+
+ # unpacking columnwise
+ if qzeros is not None:
+ izeros = torch.bitwise_right_shift(
+ qzeros[:, :, None], shifts[None, None, :]
+ ).to(
+ torch.int8 # smallest dtype available
+ )
+ izeros = izeros.view(izeros.shape[0], -1)
+ else:
+ izeros = qzeros
+
+ return iweights, izeros
+
+
+def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
+ reverse_order_tensor = torch.arange(
+ iweights.shape[-1],
+ dtype=torch.int32,
+ device=izeros.device,
+ )
+ reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+ reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+ reverse_order_tensor = reverse_order_tensor.view(-1)
+
+ if izeros is not None:
+ izeros = izeros[:, reverse_order_tensor]
+ iweights = iweights[:, reverse_order_tensor]
+
+ return iweights, izeros
+
+
+def unpack_weight_and_zeros(qweight, qzeros, bits):
+ # Unpack the qweight and qzeros tensors
+ iweight, izeros = unpack_awq(qweight, qzeros, bits)
+ # Reverse the order of the iweight and izeros tensors
+ iweight, izeros = reverse_awq_order(iweight, izeros, bits)
+
+ # overflow checks
+ iweight = torch.bitwise_and(iweight, (2**bits) - 1)
+ izeros = torch.bitwise_and(izeros, (2**bits) - 1)
+
+ return iweight, izeros
+
+
+def pack_tensor(input, bits=4):
+ normal = input.to(torch.int32)
+ q = torch.zeros(
+ (normal.shape[0], normal.shape[1] // 32 * bits),
+ dtype=torch.int32,
+ device=input.device,
+ )
+ i = 0
+ col = 0
+ while col < q.shape[1]:
+ for j in range(i, i + (32 // bits)):
+ q[:, col] |= normal[:, j] << (bits * (j - i))
+ i += 32 // bits
+ col += 1
+ q = q.to(torch.int32)
+ return q
+
+
+class WQLinear(nn.Module):
+ def __init__(
+ self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+ ):
+ super().__init__()
+
+ if w_bit not in [4]:
+ raise NotImplementedError("Only 4-bit are supported for now.")
+
+ self.in_features = qweight.shape[0]
+ self.out_features = qweight.shape[1] * 32 // w_bit
+
+ self.w_bit = w_bit
+ self.group_size = group_size if group_size != -1 else self.in_features
+        # quick sanity check (make sure the shapes are properly aligned)
+ assert self.in_features % self.group_size == 0
+ assert self.out_features % (32 // self.w_bit) == 0
+
+ self.qweight = qweight
+ self.qzeros = qzeros
+ self.scales = scales
+ self.bias = bias
+ self._preprocessing()
+
+ def _preprocessing(self):
+ device = self.qweight.device
+ weight, zeros = unpack_weight_and_zeros(
+ self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
+ )
+ self.qweight = pack_tensor(weight).to(device)
+ self.qzeros = pack_tensor(zeros).to(device)
+
+ @torch.no_grad()
+ def forward(self, x):
+ out_shape = x.shape[:-1] + (self.out_features,)
+ x = x.reshape(-1, x.shape[-1])
+ weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
+ outputs = torch.matmul(x, weights)
+
+ outputs = outputs + self.bias if self.bias is not None else outputs
+ outputs = outputs.reshape(out_shape)
+ return outputs
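+
+
+# Shape note (example values): for a 4-bit layer whose qweight has shape (4096, 512),
+# in_features == 4096 and out_features == 512 * 32 // 4 == 4096; group_size == -1 is
+# treated as a single group spanning all input features.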
diff --git a/backends/gaudi/server/text_generation_server/layers/bnb.py b/backends/gaudi/server/text_generation_server/layers/bnb.py
new file mode 100644
index 000000000..791d9b6d8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/bnb.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+
+import bitsandbytes as bnb
+import torch
+from bitsandbytes.nn import Int8Params, Params4bit
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+@dataclass
+class BNBWeight(UnquantizedWeight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
+
+
+class Linear8bitLt(torch.nn.Module):
+ def __init__(
+ self,
+ weight,
+ bias,
+ has_fp16_weights=True,
+ memory_efficient_backward=False,
+ threshold=0.0,
+ index=None,
+ ):
+ super().__init__()
+ assert (
+ not memory_efficient_backward
+ ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+ self.state = bnb.MatmulLtState()
+ self.index = index
+
+ # Necessary for stacked layers
+ self.state.threshold = threshold
+ self.state.has_fp16_weights = has_fp16_weights
+ self.state.memory_efficient_backward = memory_efficient_backward
+ if threshold > 0.0 and not has_fp16_weights:
+ self.state.use_pool = True
+
+ self.weight = Int8Params(
+ weight.data,
+ has_fp16_weights=has_fp16_weights,
+ requires_grad=has_fp16_weights,
+ )
+ self.weight.cuda(weight.device)
+ self.bias = bias
+
+ def init_8bit_state(self):
+ self.state.CB = self.weight.CB
+ self.state.SCB = self.weight.SCB
+ self.weight.CB = None
+ self.weight.SCB = None
+
+ def forward(self, x: torch.Tensor):
+ self.state.is_training = self.training
+ if self.weight.CB is not None:
+ self.init_8bit_state()
+
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
+ if self.bias is not None and self.bias.dtype != x.dtype:
+ self.bias.data = self.bias.data.to(x.dtype)
+
+ out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+
+ if not self.state.has_fp16_weights:
+ if self.state.CB is not None and self.state.CxB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+ return out
+
+
+@dataclass
+class BNBFP4Weight(UnquantizedWeight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ return Linear4bit(self.weight, bias, quant_type="fp4")
+
+
+@dataclass
+class BNBNF4Weight(UnquantizedWeight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ return Linear4bit(self.weight, bias, quant_type="nf4")
+
+
+class Linear4bit(torch.nn.Module):
+ def __init__(self, weight, bias, quant_type):
+ super().__init__()
+ self.weight = Params4bit(
+ weight.data,
+ requires_grad=False,
+ compress_statistics=True,
+ quant_type=quant_type,
+ )
+ self.compute_dtype = None
+ self.weight.cuda(weight.device)
+ self.bias = bias
+
+ def forward(self, x: torch.Tensor):
+        # weights are cast automatically as Params4bit, but the bias has to be cast manually
+ if self.bias is not None and self.bias.dtype != x.dtype:
+ self.bias.data = self.bias.data.to(x.dtype)
+
+ if getattr(self.weight, "quant_state", None) is None:
+ print(
+ "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+ )
+ inp_dtype = x.dtype
+ if self.compute_dtype is not None:
+ x = x.to(self.compute_dtype)
+
+ bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+ out = bnb.matmul_4bit(
+ x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+ )
+
+ out = out.to(inp_dtype)
+
+ return out
diff --git a/backends/gaudi/server/text_generation_server/layers/conv.py b/backends/gaudi/server/text_generation_server/layers/conv.py
new file mode 100644
index 000000000..7fb18ab3f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/conv.py
@@ -0,0 +1,41 @@
+from accelerate import init_empty_weights
+import torch
+
+
+@classmethod
+def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ bias = weights.get_tensor(f"{prefix}.bias")
+ with init_empty_weights():
+ conv2d = cls(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ )
+
+ conv2d.weight = torch.nn.Parameter(weight)
+ conv2d.bias = torch.nn.Parameter(bias)
+ return conv2d
+
+
+@classmethod
+def load_conv2d_no_bias(
+ cls, prefix, weights, in_channels, out_channels, kernel_size, stride
+):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ with init_empty_weights():
+ conv2d = cls(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ )
+
+ conv2d.weight = torch.nn.Parameter(weight)
+ conv2d.bias = None
+ return conv2d
+
+
+torch.nn.Conv2d.load = load_conv2d
+torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
diff --git a/backends/gaudi/server/text_generation_server/layers/exl2.py b/backends/gaudi/server/text_generation_server/layers/exl2.py
new file mode 100644
index 000000000..a6e07f453
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/exl2.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import torch
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+
+
+@dataclass
+class Exl2Weight(Weight):
+ """
+ Exllama2 exl2 quantized weights.
+ """
+
+ q_weight: torch.Tensor
+ q_scale: torch.Tensor
+ q_invperm: torch.Tensor
+ q_scale_max: torch.Tensor
+ q_groups: torch.Tensor
+
+ def __post_init__(self):
+ self.q_scale_max /= 256
+ self.q_invperm = self.q_invperm.short()
+
+ @property
+ def device(self) -> torch.device:
+ return self.q_weight.device
+
+ def get_linear(self, bias: torch.Tensor):
+ from text_generation_server.layers.gptq import ExllamaQuantLinear
+
+ return ExllamaQuantLinear(self, bias)
+
+
+class Exl2WeightsLoader(WeightsLoader):
+ """Loader for exl2-quantized weights."""
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ """
+ Get weights at the given prefix and apply without tensor parallelism.
+ """
+ try:
+ q_weight = weights.get_tensor(f"{prefix}.q_weight")
+ except RuntimeError:
+ raise RuntimeError(
+ "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+ )
+
+ q_scale = weights.get_tensor(f"{prefix}.q_scale")
+ q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
+ q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
+ q_groups = weights.get_tensor(f"{prefix}.q_groups")
+
+ return Exl2Weight(
+ q_weight=q_weight,
+ q_scale=q_scale,
+ q_invperm=q_invperm,
+ q_scale_max=q_scale_max,
+ q_groups=q_groups,
+ )
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ raise RuntimeError("Column-packed weights are not supported for exl")
+
+ def get_weights_col(self, weights: Weights, prefix: str):
+ # Sharding is not yet supported, so we return the weights as-is.
+ return self.get_weights(weights, prefix)
+
+ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+ raise ValueError("get_multi_weights_col is not supported for exl2")
+
+ def get_weights_row(self, weights: Weights, prefix: str):
+ # Sharding is not yet supported, so we return the weights as-is.
+ return self.get_weights(weights, prefix)
diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py
new file mode 100644
index 000000000..0dc5cdafd
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/fp8.py
@@ -0,0 +1,458 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Type, Union, List
+
+import torch
+
+from text_generation_server.utils.weights import (
+ Weight,
+ WeightsLoader,
+ UnquantizedWeight,
+ Weights,
+)
+
+from vllm_hpu_extension.ops import scaled_fp8_quant
+from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
+import habana_frameworks.torch.utils.experimental as htexp
+
+w8a8_block_fp8_matmul = None
+per_token_group_quant_fp8 = None
+quant_dtype: torch.dtype = torch.float8_e4m3fn
+
+
+def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
+ """
+ Return an FP8 linear `Module` that is compatible with the current system.
+ """
+ # On other systems let Torch decide if the hardware supports FP8.
+ return Fp8Linear
+
+
+def normalize_e4m3fn_to_native_float8(
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ return weight, weight_scale, input_scale
+
+
+def per_tensor_dequantize(
+ tensor: torch.Tensor,
+ inv_scale: Union[float, torch.Tensor],
+ dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+ device = tensor.device
+ dtype = torch.bfloat16
+ if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+ # dequant on cpu to avoid nan on gaudi2
+ tensor = tensor.to("cpu")
+
+ fake_qweight = tensor.to(dtype).to(device)
+ dq_weight = fake_qweight * inv_scale
+ return dq_weight
+
+
+def requantize_with_max_scale(
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ logical_widths: List[int],
+ dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Max scale to be used for requantization.
+ max_w_scale = weight_scale.max()
+
+ if is_hpu_gaudi2():
+ max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
+
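+ # Dequantize each logical slice with its original scale, then requantize it with the shared max scale so the concatenated weight uses a single scale.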
+ start = 0
+ for idx, logical_width in enumerate(logical_widths):
+ end = start + logical_width
+ weight_dq = per_tensor_dequantize(
+ weight[start:end, :], weight_scale[idx], dtype
+ )
+ weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+ weight_dq, max_w_scale
+ )
+ start = end
+
+ return weight, max_w_scale_normalized
+
+
+def fp8_quantize(
+ weight: torch.Tensor,
+ scale: Optional[torch.Tensor] = None,
+ scale_upper_bound: Optional[torch.Tensor] = None,
+ qdtype: torch.dtype = torch.float8_e4m3fn,
+ scalar: bool = False,
+):
+ """
+ This function returns a reciprocal of the scale, so that a tensor can be unscaled
+ by multiplying it with the returned scale. If a scale is given through the `scale`
+ argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
+ be used without modification).
+ """
+ shape = weight.shape
+ qweight, scale = scaled_fp8_quant(
+ weight.reshape(-1, shape[-1]),
+ scale=scale,
+ scale_ub=scale_upper_bound,
+ # TODO: don't do this when we have to use the Torch kernel.
+ use_per_token_if_dynamic=not scalar,
+ )
+
+ return qweight.reshape(shape), scale
+
+
+class HybridFP8UnquantLoader(WeightsLoader):
+ """Weight loader that loads FP8 and unquantized Torch tensors."""
+
+ def __init__(
+ self,
+ activation_scale_ub: Optional[float],
+ to_fp8: bool,
+ weight_block_size: Optional[List[int]] = None,
+ ):
+ self.activation_scale_ub = activation_scale_ub
+ self.to_fp8 = to_fp8
+ self.weight_block_size = weight_block_size
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ w = weights.get_tensor(f"{prefix}.weight")
+
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+ # FP8 branch
+ scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+ input_scale = None
+ if weights.has_tensor(f"{prefix}.input_scale"):
+ input_scale = (
+ weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+ .reshape(-1)
+ .max()
+ )
+ logical_widths = [w.shape[0]]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(0), logical_widths, weights.dtype
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ w = weights.get_packed_sharded(
+ f"{prefix}.weight", dim=0, block_sizes=block_sizes
+ )
+
+ if w.dtype == torch.float8_e4m3fn:
+ # FP8 branch
+ scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+ if scale.numel() > 1:
+ scale = weights.get_packed_sharded(
+ f"{prefix}.weight_scale",
+ dim=0,
+ block_sizes=block_sizes,
+ to_dtype=False,
+ )
+
+ input_scale = None
+ if weights.has_tensor(f"{prefix}.input_scale"):
+ input_scale = weights.get_tensor(
+ f"{prefix}.input_scale", to_dtype=False
+ )
+ if input_scale.numel() > 1:
+ input_scale = weights.get_packed_sharded(
+ f"{prefix}.input_scale",
+ dim=0,
+ block_sizes=block_sizes,
+ to_dtype=False,
+ )
+ input_scale = input_scale.reshape(-1).max()
+ logical_widths = [w.shape[0]]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(0), logical_widths, weights.dtype
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+ w = [
+ weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+ ]
+ shapes = [x.shape for x in w]
+
+ # Concat then send to the device
+ w = torch.cat(w, dim=dim).to(weights.device)
+
+ # FP8 branch
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ scale = [
+ weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
+ for p in prefixes
+ ]
+ scale = torch.cat(scale, dim=dim)
+ scale = scale.to(weights.device)
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+
+ scale = [
+ _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+ for p, shape in zip(prefixes, shapes)
+ ]
+ scale = torch.cat(scale, dim=0).reshape(-1)
+
+ input_scale = [
+ _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+ for p, shape in zip(prefixes, shapes)
+ if weights.has_tensor(f"{p}.input_scale")
+ ]
+ assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+ input_scale = (
+ torch.cat(input_scale, dim=0).reshape(-1).max()
+ if len(input_scale) != 0
+ else None
+ )
+
+ logical_widths = [x[0] for x in shapes]
+ w, scale = requantize_with_max_scale(
+ w, scale.to(weights.device), logical_widths, weights.dtype
+ )
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ w = weights.get_sharded(f"{prefix}.weight", dim=1)
+ # FP8 branch
+ if w.dtype == torch.float8_e4m3fn:
+ if self.weight_block_size is not None:
+ # XXX: The tensor is named weight_scale_inv, but it appears to hold the scale itself.
+ scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
+
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ weight_block_size=self.weight_block_size,
+ )
+
+ scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+ input_scale = None
+ if weights.has_tensor(f"{prefix}.input_scale"):
+ input_scale = (
+ weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+ .reshape(-1)
+ .max()
+ )
+ logical_widths = [w.shape[0]]
+ w, scale = requantize_with_max_scale(
+ w, scale.unsqueeze(0), logical_widths, weights.dtype
+ )
+ return Fp8Weight(
+ weight=w,
+ weight_scale=scale,
+ input_scale=input_scale,
+ activation_scale_ub=self.activation_scale_ub,
+ dtype=weights.dtype,
+ )
+ if self.to_fp8:
+ return Fp8Weight(weight=w, dtype=weights.dtype)
+
+ return UnquantizedWeight(w)
+
+
+@dataclass
+class Fp8Weight(Weight):
+ weight: torch.Tensor
+ dtype: torch.dtype
+ weight_scale: Optional[torch.Tensor] = None
+ input_scale: Optional[torch.Tensor] = None
+ activation_scale_ub: Optional[float] = None
+ force_w8a16: bool = False
+ weight_block_size: Optional[List[int]] = None
+
+ def get_linear(self, bias: torch.Tensor):
+ if self.weight_scale is None:
+ return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
+ self.weight, bias, self.dtype
+ )
+ # The fbgemm kernels do not check this, but they require contiguous
+ # memory. The scale can be non-contiguous when we e.g. expand from scalars.
+ self.weight_scale = self.weight_scale.contiguous()
+ return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
+ weight=self.weight,
+ scale=self.weight_scale,
+ dtype=self.dtype,
+ bias=bias,
+ input_scale=self.input_scale,
+ scale_upper_bound=self.activation_scale_ub,
+ weight_block_size=self.weight_block_size,
+ )
+
+
+class Fp8Linear(torch.nn.Module):
+ _device_identity_cache = {}
+
+ def __init__(
+ self,
+ qweight: torch.Tensor,
+ scale: torch.Tensor,
+ dtype: torch.dtype,
+ bias: Optional[torch.Tensor] = None,
+ input_scale: Optional[torch.Tensor] = None,
+ scale_upper_bound: Optional[float] = None,
+ weight_block_size: Optional[List[int]] = None,
+ ) -> None:
+ super().__init__()
+
+ self.dtype = dtype
+ self.qweight = qweight
+ self.scale = scale.float()
+ self.input_scale = input_scale.float() if input_scale is not None else None
+ self.weight_block_size = weight_block_size
+ self.scale_upper_bound = scale_upper_bound
+
+ self.bias = bias if bias is not None else None
+
+ @classmethod
+ def from_unquant(cls, weight, bias, dtype):
+ qweight, scale = fp8_quantize(weight, scalar=True)
+ return cls(
+ qweight=qweight,
+ scale=scale,
+ dtype=dtype,
+ bias=bias,
+ input_scale=None,
+ scale_upper_bound=None,
+ )
+
+ @classmethod
+ def from_fp8(
+ cls,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ dtype: torch.dtype,
+ bias: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> "Fp8Linear":
+ input_scale = kwargs.get("input_scale", None)
+ scale_upper_bound = kwargs.get("scale_upper_bound", None)
+ weight_block_size = kwargs.get("weight_block_size", None)
+
+ return cls(
+ qweight=weight,
+ scale=scale,
+ input_scale=input_scale,
+ scale_upper_bound=scale_upper_bound,
+ bias=bias,
+ dtype=dtype,
+ weight_block_size=weight_block_size,
+ )
+
+ @classmethod
+ def get_shared_device_identity(cls, device):
+ # Input scaling factors are no longer optional in _scaled_mm starting
+ # from PyTorch 2.5, so allocate a dummy identity tensor to pass as input_scale.
+ if device not in cls._device_identity_cache:
+ cls._device_identity_cache[device] = torch.ones(1, device=device)
+ return cls._device_identity_cache[device]
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
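+ # Two paths: block-quantized matmul when weight_block_size is set, otherwise per-tensor scaling via torch._scaled_mm.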
+ if self.weight_block_size is not None:
+ # https://arxiv.org/pdf/2412.19437
+ # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
+ # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
+ # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
+ # channels).
+ qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
+ output = w8a8_block_fp8_matmul(
+ qinput,
+ self.qweight,
+ scale,
+ self.scale,
+ self.weight_block_size,
+ output_dtype=input.dtype,
+ )
+
+ if self.bias is not None:
+ output = output + self.bias
+ return output.to(dtype=input.dtype)
+
+ qinput, scale = fp8_quantize(
+ input,
+ self.input_scale,
+ scale_upper_bound=self.scale_upper_bound,
+ scalar=True,
+ )
+
+ output = torch._scaled_mm(
+ qinput,
+ self.qweight.t(),
+ out_dtype=self.dtype,
+ scale_a=scale,
+ scale_b=self.scale,
+ bias=self.bias,
+ )
+
+ if isinstance(output, tuple) and len(output) == 2:
+ output = output[0]
+
+ return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
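+ # Scalar scales can be loaded whole; per-channel (matrix) scales must be sharded along dim 0 like the weight.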
+ scale = weights.get_tensor(prefix, to_dtype=False)
+
+ if scale.numel() > 1:
+ scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+ return scale.reshape(-1)
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py
new file mode 100644
index 000000000..90b8f6923
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py
@@ -0,0 +1,357 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import torch
+from loguru import logger
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+
+
+from .hpu import QuantLinear
+
+
+@dataclass
+class GPTQWeight(Weight):
+ qweight: torch.Tensor
+ qzeros: torch.Tensor
+ scales: torch.Tensor
+ g_idx: Optional[torch.Tensor]
+ bits: int
+ groupsize: int
+ use_awq_kernel: bool
+ use_exllama: bool
+
+ def __post_init__(self):
+ if self.scales.dtype == torch.float:
+ self.scales = self.scales.half()
+
+ @property
+ def device(self) -> torch.device:
+ return self.qweight.device
+
+ def get_linear(self, bias: torch.Tensor):
+ if self.use_awq_kernel:
+ try:
+ from text_generation_server.layers.awq.quantize import WQLinear
+
+ return WQLinear(
+ w_bit=self.bits,
+ group_size=self.groupsize,
+ qweight=self.qweight,
+ qzeros=self.qzeros,
+ scales=self.scales,
+ bias=bias,
+ )
+ except ImportError:
+ raise NotImplementedError(
+ "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+ )
+ else:
+ return QuantLinear(
+ self.qweight,
+ self.qzeros,
+ self.scales,
+ self.g_idx,
+ bias,
+ self.bits,
+ self.groupsize,
+ )
+
+
+class GPTQWeightsLoader(WeightsLoader):
+ """
+ Loader for GPTQ- and AWQ-quantized weights.
+ """
+
+ def __init__(
+ self,
+ *,
+ bits: int,
+ desc_act: bool,
+ groupsize: int,
+ quant_method: str,
+ quantize: str,
+ sym: bool,
+ ):
+ self.bits = bits
+ self.desc_act = desc_act
+ self.groupsize = groupsize
+ self.quant_method = quant_method
+ self.quantize = quantize
+ self.sym = sym
+
+ def get_weights(self, weights: Weights, prefix: str):
+ self._get_gptq_params(weights)
+
+ use_exllama = True
+ if self.bits != 4:
+ use_exllama = False
+
+ if self.desc_act:
+ log_once(logger.warning, "Disabling exllama because desc_act=True")
+ use_exllama = False
+
+ try:
+ qweight = weights.get_tensor(f"{prefix}.qweight")
+ except RuntimeError:
+ raise RuntimeError(
+ "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+ )
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ g_idx = weights.get_tensor(f"{prefix}.g_idx")
+ else:
+ g_idx = None
+
+ qzeros = weights.get_tensor(f"{prefix}.qzeros")
+ scales = weights.get_tensor(f"{prefix}.scales")
+
+ if use_exllama and g_idx is not None:
+ g_idx = g_idx - g_idx[0]
+
+ if self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def get_weights_col_packed(
+ self,
+ weights: Weights,
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ try:
+ qweight = weights.get_packed_sharded(
+ f"{prefix}.qweight", dim=1, block_sizes=block_sizes
+ )
+ except RuntimeError:
+ raise RuntimeError(
+ f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
+ )
+ scales = weights.get_packed_sharded(
+ f"{prefix}.scales", dim=1, block_sizes=block_sizes
+ )
+ scales = scales.to(dtype=weights.dtype)
+
+ self._get_gptq_params(weights)
+
+ qzeros = weights.get_packed_sharded(
+ f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+ )
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ g_idx = weights.get_tensor(f"{prefix}.g_idx")
+ elif self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+ else:
+ g_idx = None
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=False,
+ )
+
+ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+ try:
+ qweight = torch.cat(
+ [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+ )
+ except RuntimeError:
+ raise RuntimeError(
+ f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+ )
+
+ scales = torch.cat(
+ [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+ )
+
+ self._get_gptq_params(weights)
+
+ qzeros = torch.cat(
+ [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+ )
+
+ use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+ for w2 in w[1:]:
+ torch.testing.assert_close(w2, w[0])
+ g_idx = w[0]
+ elif self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+ else:
+ g_idx = None
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def get_weights_row(self, weights: Weights, prefix: str):
+ self._get_gptq_params(weights)
+
+ use_exllama = True
+ desc_act = self.desc_act
+ if self.bits != 4:
+ use_exllama = False
+
+ if self.desc_act:
+ log_once(logger.warning, "Disabling exllama because desc_act=True")
+ use_exllama = False
+
+ try:
+ qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
+ except RuntimeError:
+ raise RuntimeError(
+ "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+ )
+
+ if self.quantize == "gptq" and self.quant_method == "gptq":
+ g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
+ else:
+ g_idx = None
+
+ if weights.process_group.size() > 1:
+ if g_idx is not None:
+ if (
+ not torch.equal(
+ # Remove g_idx[0] to adapt the check with TP>1.
+ (g_idx - g_idx[0]).cpu(),
+ torch.tensor(
+ [i // self.groupsize for i in range(g_idx.shape[0])],
+ dtype=torch.int32,
+ ),
+ )
+ and not (g_idx == 0).all()
+ ):
+ # The exllama implementation does not support row tensor parallelism with act-order, as
+ # it would require reordering input activations that are split onto several GPUs
+ use_exllama = False
+ desc_act = True
+
+ from text_generation_server.layers.gptq import (
+ GPTQWeight,
+ )
+
+ if not desc_act and self.groupsize != -1:
+ qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
+ scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+ if g_idx is not None:
+ # qzeros, scales sharded, and g_idx must be adjusted accordingly
+ g_idx = g_idx - g_idx[0]
+ else:
+ qzeros = weights.get_tensor(f"{prefix}.qzeros")
+ scales = weights.get_tensor(f"{prefix}.scales")
+
+ if self.quantize == "gptq" and self.quant_method == "awq":
+ log_once(
+ logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+ )
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+ if use_exllama:
+ g_idx = None
+ else:
+ g_idx = (
+ torch.arange(
+ qweight.shape[0] * (32 // self.bits),
+ device=qweight.device,
+ )
+ // self.groupsize
+ ).to(dtype=torch.int32)
+
+ return GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=self.bits,
+ groupsize=self.groupsize,
+ use_awq_kernel=self.quantize == "awq",
+ use_exllama=use_exllama,
+ )
+
+ def _get_gptq_params(self, weights: Weights):
+ if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
+ self.bits = weights.get_tensor("gptq_bits").item()
+ self.groupsize = weights.get_tensor("gptq_groupsize").item()
+ self.desc_act = False
+ # `server quantize` used asymmetric quantization unconditionally
+ # before the `gptq_sym` setting tensor was added.
+ self.sym = (
+ weights.get_tensor("gptq_sym").item()
+ if weights.has_tensor("gptq_sym")
+ else False
+ )
+ self.quant_method = "gptq"
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
new file mode 100644
index 000000000..72944fa0e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
@@ -0,0 +1,186 @@
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+
+try:
+ convert_from_uint4 = torch.ops.hpu.convert_from_uint4
+except Exception as e:
+ hpu_import_exception = e
+
+ def error_raiser_hpu(*args, **kwargs):
+ raise ValueError(
+ f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
+ )
+
+ convert_from_uint4 = error_raiser_hpu
+
+
+def pack_tensor(input, bits=4):
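+ # Pack groups of 32 // bits low-bit values into each int32 of the output along the column dimension.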
+ normal = input.to(torch.int32)
+ q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
+ i = 0
+ col = 0
+ while col < q.shape[1]:
+ for j in range(i, i + (32 // bits)):
+ q[:, col] |= normal[:, j] << (bits * (j - i))
+ i += 32 // bits
+ col += 1
+ q = q.to(torch.int32)
+ return q
+
+
+class QuantLinear(nn.Module):
+ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+ super().__init__()
+ self.register_buffer("qweight", qweight)
+ self.register_buffer("qzeros", qzeros)
+ self.register_buffer("scales", scales)
+ self.register_buffer("g_idx", g_idx)
+ if bias is not None:
+ self.register_buffer("bias", bias)
+ else:
+ self.bias = None
+ if bits not in [4]:
+ raise NotImplementedError("Only 4 bits are supported.")
+ self.bits = bits
+ self.maxq = 2**self.bits - 1
+ self.groupsize = groupsize
+
+ self.outfeatures = qweight.shape[1]
+ self.infeatures = qweight.shape[0] * 32 // bits
+ self.wf = torch.tensor(
+ list(range(0, 32, self.bits)), dtype=torch.int32
+ ).unsqueeze(0)
+ self._preprocessing()
+
+ def unpack_zeros_from_cuda_old_format(self):
+ zeros = torch.bitwise_right_shift(
+ torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
+ self.wf.unsqueeze(0),
+ ).to(torch.int16 if self.bits == 8 else torch.int8)
+
+ zeros = zeros + 1
+ zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
+ self.scales.dtype
+ ) # NOTE: casting here, after `zeros = zeros + 1`, appears to be important.
+ zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
+ return zeros
+
+ def unpack_weight_from_cuda_old_format(self):
+ weight = torch.bitwise_right_shift(
+ torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
+ self.wf.unsqueeze(-1),
+ ).to(torch.int16 if self.bits == 8 else torch.int8)
+ weight = torch.bitwise_and(weight, (2**self.bits) - 1)
+ weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
+ return weight
+
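+ # Repack the CUDA/exllama-format qweight and qzeros into the dense uint4 layout expected by convert_from_uint4 on HPU; only a trivial (sequential) g_idx is supported.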
+ def _preprocessing(self):
+ orig_device = self.qweight.device
+ self.qweight = self.qweight.cpu()
+ weight = self.unpack_weight_from_cuda_old_format()
+ new_qweight = pack_tensor(weight)
+ self.qweight = new_qweight.to(orig_device)
+ # TODO: Support group indexing and remove the check
+ columns = self.qweight.shape[0]
+ g_idx_trivial = [i // self.groupsize for i in range(columns)]
+ g_idx_trivial = torch.tensor(
+ g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
+ )
+ assert torch.equal(
+ self.g_idx, g_idx_trivial
+ ), "Non-trivial tensor g_idx is not supported"
+ self.qzeros = self.qzeros.cpu()
+ zeros = self.unpack_zeros_from_cuda_old_format()
+ new_qzeros = pack_tensor(zeros)
+ self.qzeros = new_qzeros.to(orig_device)
+
+ @classmethod
+ def new(cls, bits, groupsize, infeatures, outfeatures, bias):
+ if bits not in [4]:
+ raise NotImplementedError("Only 4 bits are supported.")
+
+ qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
+ qzeros = torch.zeros(
+ (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
+ dtype=torch.int32,
+ )
+ scales = torch.zeros(
+ (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
+ )
+ g_idx = torch.tensor(
+ [i // groupsize for i in range(infeatures)], dtype=torch.int32
+ )
+ if bias:
+ bias = torch.zeros((outfeatures), dtype=torch.float16)
+ else:
+ bias = None
+ return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+
+ def pack(self, linear, scales, zeros, g_idx=None):
+ self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
+ intweight = []
+ for idx in range(self.infeatures):
+ intweight.append(
+ torch.round(
+ (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+ / self.scales[self.g_idx[idx]]
+ ).to(torch.int)[:, None]
+ )
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(np.uint32)
+ qweight = np.zeros(
+ (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+ )
+ i = 0
+ row = 0
+ while row < qweight.shape[0]:
+ if self.bits in [4]:
+ for j in range(i, i + (32 // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 4 bits are supported.")
+
+ qweight = qweight.astype(np.int32)
+ self.qweight = torch.from_numpy(qweight)
+
+ zeros -= 1
+ zeros = zeros.numpy().astype(np.uint32)
+ qzeros = np.zeros(
+ (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
+ )
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [4]:
+ for j in range(i, i + (32 // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 4 bits are supported.")
+
+ qzeros = qzeros.astype(np.int32)
+ self.qzeros = torch.from_numpy(qzeros)
+
+ def forward(self, x):
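+ # Dequantize on the fly with the HPU convert_from_uint4 op, then perform a regular matmul.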
+ out_shape = x.shape[:-1] + (self.outfeatures,)
+ x = x.reshape(-1, x.shape[-1])
+ weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
+ out = torch.matmul(x, weight)
+ out = out.reshape(out_shape)
+ out = out + self.bias if self.bias is not None else out
+ return out
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py
new file mode 100644
index 000000000..aa664ea60
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/quantize.py
@@ -0,0 +1,1026 @@
+import time
+import torch.nn as nn
+import math
+import json
+import os
+import torch
+import transformers
+
+from texttable import Texttable
+from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
+from huggingface_hub import HfApi
+from accelerate import init_empty_weights
+from text_generation_server.utils import initialize_torch_distributed, Weights
+from text_generation_server.utils.hub import weight_files
+from text_generation_server.layers.gptq import QuantLinear
+from loguru import logger
+from typing import Optional
+from text_generation_server.layers.gptq.utils import torch_snr_error
+
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
+
+DEV = torch.device("cuda:0")
+
+
+class Quantizer(nn.Module):
+ def __init__(self, shape=1):
+ super(Quantizer, self).__init__()
+ self.register_buffer("maxq", torch.tensor(0))
+ self.register_buffer("scale", torch.zeros(shape))
+ self.register_buffer("zero", torch.zeros(shape))
+
+ def configure(
+ self,
+ bits,
+ perchannel=False,
+ sym=True,
+ mse=False,
+ norm=2.4,
+ grid=100,
+ maxshrink=0.8,
+ trits=False,
+ ):
+ self.maxq = torch.tensor(2**bits - 1)
+ self.perchannel = perchannel
+ self.sym = sym
+ self.mse = mse
+ self.norm = norm
+ self.grid = grid
+ self.maxshrink = maxshrink
+ if trits:
+ self.maxq = torch.tensor(-1)
+ self.scale = torch.zeros_like(self.scale)
+
+ def _quantize(self, x, scale, zero, maxq):
+ if maxq < 0:
+ return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+ return scale * (q - zero)
+
+ def find_params(self, x, weight=False):
+ dev = x.device
+ self.maxq = self.maxq.to(dev)
+
+ shape = x.shape
+ if self.perchannel:
+ if weight:
+ x = x.flatten(1)
+ else:
+ if len(shape) == 4:
+ x = x.permute([1, 0, 2, 3])
+ x = x.flatten(1)
+ if len(shape) == 3:
+ x = x.reshape((-1, shape[-1])).t()
+ if len(shape) == 2:
+ x = x.t()
+ else:
+ x = x.flatten().unsqueeze(0)
+
+ tmp = torch.zeros(x.shape[0], device=dev)
+ xmin = torch.minimum(x.min(1)[0], tmp)
+ xmax = torch.maximum(x.max(1)[0], tmp)
+
+ if self.sym:
+ xmax = torch.maximum(torch.abs(xmin), xmax)
+ tmp = xmin < 0
+ if torch.any(tmp):
+ xmin[tmp] = -xmax[tmp]
+ tmp = (xmin == 0) & (xmax == 0)
+ xmin[tmp] = -1
+ xmax[tmp] = +1
+
+ if self.maxq < 0:
+ self.scale = xmax
+ self.zero = xmin
+ else:
+ self.scale = (xmax - xmin) / self.maxq
+ if self.sym:
+ self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
+ else:
+ self.zero = torch.round(-xmin / self.scale)
+
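+ # Optional MSE search: shrink the range by a grid of factors and keep the scale/zero pair with the lowest p-norm quantization error per row.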
+ if self.mse:
+ best = torch.full([x.shape[0]], float("inf"), device=dev)
+ for i in range(int(self.maxshrink * self.grid)):
+ p = 1 - i / self.grid
+ xmin1 = p * xmin
+ xmax1 = p * xmax
+ scale1 = (xmax1 - xmin1) / self.maxq
+ zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
+ q = self._quantize(
+ x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
+ )
+ q -= x
+ q.abs_()
+ q.pow_(self.norm)
+ err = torch.sum(q, 1)
+ tmp = err < best
+ if torch.any(tmp):
+ best[tmp] = err[tmp]
+ self.scale[tmp] = scale1[tmp]
+ self.zero[tmp] = zero1[tmp]
+ if not self.perchannel:
+ if weight:
+ tmp = shape[0]
+ else:
+ tmp = shape[1] if len(shape) != 3 else shape[2]
+ self.scale = self.scale.repeat(tmp)
+ self.zero = self.zero.repeat(tmp)
+
+ if weight:
+ shape = [-1] + [1] * (len(shape) - 1)
+ self.scale = self.scale.reshape(shape)
+ self.zero = self.zero.reshape(shape)
+ return
+ if len(shape) == 4:
+ self.scale = self.scale.reshape((1, -1, 1, 1))
+ self.zero = self.zero.reshape((1, -1, 1, 1))
+ if len(shape) == 3:
+ self.scale = self.scale.reshape((1, 1, -1))
+ self.zero = self.zero.reshape((1, 1, -1))
+ if len(shape) == 2:
+ self.scale = self.scale.unsqueeze(0)
+ self.zero = self.zero.unsqueeze(0)
+
+ def quantize(self, x):
+ if self.ready():
+ return self._quantize(x, self.scale, self.zero, self.maxq)
+
+ return x
+
+ def enabled(self):
+ return self.maxq > 0
+
+ def ready(self):
+ return torch.all(self.scale != 0)
+
+
+class GPTQ:
+ def __init__(self, layer, observe=False):
+ self.layer = layer
+ self.dev = self.layer.weight.device
+ W = layer.weight.data.clone()
+ if isinstance(self.layer, nn.Conv2d):
+ W = W.flatten(1)
+ if isinstance(self.layer, transformers.Conv1D):
+ W = W.t()
+ self.rows = W.shape[0]
+ self.columns = W.shape[1]
+ self.H = torch.zeros((self.columns, self.columns), device=self.dev)
+ self.nsamples = 0
+ self.quantizer = Quantizer()
+ self.observe = observe
+
+ def add_batch(self, inp, out):
+ # Hessian H = 2 X XT + λ I
+ if self.observe:
+ self.inp1 = inp
+ self.out1 = out
+ else:
+ self.inp1 = None
+ self.out1 = None
+
+ if len(inp.shape) == 2:
+ inp = inp.unsqueeze(0)
+ tmp = inp.shape[0]
+ if isinstance(self.layer, nn.Linear) or isinstance(
+ self.layer, transformers.Conv1D
+ ):
+ if len(inp.shape) == 3:
+ inp = inp.reshape((-1, inp.shape[-1]))
+ inp = inp.t()
+ if isinstance(self.layer, nn.Conv2d):
+ unfold = nn.Unfold(
+ self.layer.kernel_size,
+ dilation=self.layer.dilation,
+ padding=self.layer.padding,
+ stride=self.layer.stride,
+ )
+ inp = unfold(inp)
+ inp = inp.permute([1, 0, 2])
+ inp = inp.flatten(1)
+ self.H *= self.nsamples / (self.nsamples + tmp)
+ self.nsamples += tmp
+ # inp = inp.float()
+ inp = math.sqrt(2 / self.nsamples) * inp.float()
+ # self.H += 2 / self.nsamples * inp.matmul(inp.t())
+ self.H += inp.matmul(inp.t())
+
+ def print_loss(self, name, q_weight, weight_error, timecost):
+ table = Texttable()
+ length = 28
+ name = (
+ (name + " " * (length - len(name)))
+ if len(name) <= length
+ else name[:length]
+ )
+
+ table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])
+
+ # assign weight
+ self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
+ self.layer.weight.data.dtype
+ )
+
+ if self.inp1 is not None:
+ # quantize input to int8
+ quantizer = Quantizer()
+ quantizer.configure(8, perchannel=False, sym=True, mse=False)
+ quantizer.find_params(self.inp1)
+ q_in = quantizer.quantize(self.inp1).type(torch.float16)
+ q_out = self.layer(q_in)
+
+ # get kinds of SNR
+ q_SNR = torch_snr_error(q_out, self.out1).item()
+ fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
+ else:
+ q_SNR = "-"
+ fp_SNR = "-"
+
+ table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
+ print(table.draw().split("\n")[-2])
+
+ def fasterquant(
+ self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
+ ):
+ self.layer.to(self.dev)
+
+ W = self.layer.weight.data.clone()
+ if isinstance(self.layer, nn.Conv2d):
+ W = W.flatten(1)
+ if isinstance(self.layer, transformers.Conv1D):
+ W = W.t()
+ W = W.float()
+
+ tick = time.time()
+
+ if not self.quantizer.ready():
+ self.quantizer.find_params(W, weight=True)
+
+ H = self.H
+ if not self.observe:
+ del self.H
+ dead = torch.diag(H) == 0
+ H[dead, dead] = 1
+ W[:, dead] = 0
+
+ if act_order:
+ perm = torch.argsort(torch.diag(H), descending=True)
+ W = W[:, perm]
+ H = H[perm][:, perm]
+
+ Losses = torch.zeros_like(W)
+ Q = torch.zeros_like(W)
+
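+ # Dampen the Hessian diagonal for numerical stability, then compute the upper Cholesky factor of its inverse, used for the error propagation below.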
+ damp = percdamp * torch.mean(torch.diag(H))
+ diag = torch.arange(self.columns, device=self.dev)
+ H[diag, diag] += damp
+ H = torch.linalg.cholesky(H)
+ H = torch.cholesky_inverse(H)
+ try:
+ H = torch.linalg.cholesky(H, upper=True)
+ except Exception:
+ # Addition because Falcon fails on h_to_4h
+ H = torch.linalg.cholesky(
+ H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
+ )
+ Hinv = H
+
+ g_idx = []
+ scale = []
+ zero = []
+ now_idx = 1
+
+ for i1 in range(0, self.columns, blocksize):
+ i2 = min(i1 + blocksize, self.columns)
+ count = i2 - i1
+
+ W1 = W[:, i1:i2].clone()
+ Q1 = torch.zeros_like(W1)
+ Err1 = torch.zeros_like(W1)
+ Losses1 = torch.zeros_like(W1)
+ Hinv1 = Hinv[i1:i2, i1:i2]
+
+ for i in range(count):
+ w = W1[:, i]
+ d = Hinv1[i, i]
+
+ if groupsize != -1:
+ if (i1 + i) % groupsize == 0:
+ self.quantizer.find_params(
+ W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
+ )
+
+ if ((i1 + i) // groupsize) - now_idx == -1:
+ scale.append(self.quantizer.scale)
+ zero.append(self.quantizer.zero)
+ now_idx += 1
+
+ q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
+ Q1[:, i] = q
+ Losses1[:, i] = (w - q) ** 2 / d**2
+
+ err1 = (w - q) / d
+ W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+ Err1[:, i] = err1
+
+ Q[:, i1:i2] = Q1
+ Losses[:, i1:i2] = Losses1 / 2
+
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+ torch.cuda.synchronize()
+ error = torch.sum(Losses).item()
+
+ groupsize = groupsize if groupsize != -1 else self.columns
+ g_idx = [i // groupsize for i in range(self.columns)]
+ g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
+ if act_order:
+ invperm = torch.argsort(perm)
+ Q = Q[:, invperm]
+ g_idx = g_idx[invperm]
+
+ if isinstance(self.layer, transformers.Conv1D):
+ Q = Q.t()
+
+ self.print_loss(
+ name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
+ )
+
+ if scale == []:
+ scale.append(self.quantizer.scale)
+ zero.append(self.quantizer.zero)
+ scale = torch.cat(scale, dim=1)
+ zero = torch.cat(zero, dim=1)
+ return scale, zero, g_idx, error
+
+ def free(self):
+ self.inp1 = None
+ self.out1 = None
+ self.H = None
+ self.Losses = None
+ self.Trace = None
+ torch.cuda.empty_cache()
+
+
+def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
+ testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
+ testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
+ valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
+ testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+ split="train",
+ use_auth_token=False,
+ )
+ valdata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+ split="validation",
+ use_auth_token=False,
+ )
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ while True:
+ i = random.randint(0, len(traindata) - 1)
+ trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
+ if trainenc.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+
+ import random
+
+ random.seed(0)
+ valenc = []
+ for _ in range(256):
+ while True:
+ i = random.randint(0, len(valdata) - 1)
+ tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
+ if tmp.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ valenc.append(tmp.input_ids[:, i:j])
+ valenc = torch.hstack(valenc)
+
+ class TokenizerWrapper:
+ def __init__(self, input_ids):
+ self.input_ids = input_ids
+
+ valenc = TokenizerWrapper(valenc)
+
+ return trainloader, valenc
+
+
+def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
+ testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
+ testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
+ from datasets import load_dataset
+
+ traindata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+ split="train",
+ )
+ valdata = load_dataset(
+ "allenai/c4",
+ "allenai--c4",
+ data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+ split="validation",
+ )
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=False, trust_remote_code=trust_remote_code
+ )
+ except Exception:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ import random
+
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ while True:
+ i = random.randint(0, len(traindata) - 1)
+ trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
+ if trainenc.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+
+ valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
+ valenc = valenc.input_ids[:, : (256 * seqlen)]
+
+ class TokenizerWrapper:
+ def __init__(self, input_ids):
+ self.input_ids = input_ids
+
+ valenc = TokenizerWrapper(valenc)
+
+ return trainloader, valenc
+
+
+def get_loaders(
+ name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
+):
+ if "wikitext2" in name:
+ return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
+ if "ptb" in name:
+ if "new" in name:
+ return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)
+ return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)
+ if "c4" in name:
+ if "new" in name:
+ return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code)
+ return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)
+
+
+def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
+ # Skip last lm_head linear
+ # Need isinstance because Falcon's linear layers inherit from nn.Linear.
+ if isinstance(module, layers) and "lm_head" not in name:
+ return {name: module}
+ res = {}
+ for name1, child in module.named_children():
+ res.update(
+ find_layers(
+ child, layers=layers, name=name + "." + name1 if name != "" else name1
+ )
+ )
+ return res
+
+
+@torch.no_grad()
+def sequential(
+ model,
+ dataloader,
+ dev,
+ nsamples,
+ bits,
+ groupsize,
+ *,
+ hooks,
+ percdamp=0.01,
+ sym: bool = False,
+ act_order: bool = False,
+):
+ print("Starting ...")
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+ try:
+ layers = model.model.layers
+ prefix = "model.layers"
+ except Exception:
+ layers = model.transformer.h
+ prefix = "transformer.h"
+
+ dtype = next(iter(model.parameters())).dtype
+ inps = torch.zeros(
+ (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
+ )
+
+ cache = {"i": 0}
+ extra = {}
+
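+ # Catcher records the first decoder layer's inputs (and extra kwargs) for calibration, then aborts the forward pass so the rest of the model never runs.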
+ class Catcher(nn.Module):
+ def __init__(self, module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, inp, **kwargs):
+ inps[cache["i"]] = inp
+ cache["i"] += 1
+ extra.update(kwargs.copy())
+ raise ValueError
+
+ layers[0] = Catcher(layers[0])
+ for batch in dataloader:
+ try:
+ model(batch[0].cuda())
+ except ValueError:
+ pass
+ layers[0] = layers[0].module
+
+ # layers[0] = layers[0].cpu()
+ # model.model.embed_tokens = model.model.embed_tokens.cpu()
+ # model.model.norm = model.model.norm.cpu()
+ torch.cuda.empty_cache()
+ for hook in hooks:
+ hook.remove()
+
+ outs = torch.zeros_like(inps)
+
+ extra = {
+ k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
+ }
+
+ print("Ready.")
+
+ quantizers = {}
+ for i in range(len(layers)):
+ print(f"Quantizing layer {i+1}/{len(layers)}..")
+ print("+------------------+--------------+------------+-----------+-------+")
+ print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
+ print("+==================+==============+============+===========+=======+")
+
+ layer = layers[i]
+ layer.load()
+ full = find_layers(layer)
+ sequential = [list(full.keys())]
+
+ for names in sequential:
+ subset = {n: full[n] for n in names}
+ gptq = {}
+ for name in subset:
+ gptq[name] = GPTQ(subset[name])
+ gptq[name].quantizer.configure(
+ bits, perchannel=True, sym=sym, mse=False
+ )
+ pass
+
+ def add_batch(name):
+ nonlocal gptq
+
+ def tmp(_, inp, out):
+ gptq[name].add_batch(inp[0].data, out.data)
+
+ return tmp
+
+ handles = []
+ for name in subset:
+ handles.append(subset[name].register_forward_hook(add_batch(name)))
+ for j in range(nsamples):
+ outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
+ for h in handles:
+ h.remove()
+
+ for name in subset:
+ scale, zero, g_idx, error = gptq[name].fasterquant(
+ percdamp=percdamp,
+ groupsize=groupsize,
+ act_order=act_order,
+ name=name,
+ )
+ quantizers[f"{prefix}.{i}.{name}"] = (
+ gptq[name].quantizer.cpu(),
+ scale.cpu(),
+ zero.cpu(),
+ g_idx.cpu(),
+ bits,
+ groupsize,
+ )
+
+ gptq[name].free()
+
+ for j in range(nsamples):
+ outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
+
+ layer.unload()
+ del layer
+ del gptq
+ torch.cuda.empty_cache()
+
+ inps, outs = outs, inps
+ print("+------------------+--------------+------------+-----------+-------+")
+ print("\n")
+
+ model.config.use_cache = use_cache
+
+ return quantizers
+
+
+def make_quant_linear(module, names, bits, groupsize, name=""):
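+ # Recursively replace every targeted linear attribute with a zero-initialized QuantLinear of matching shape; pack() fills in the quantized tensors afterwards.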
+ if isinstance(module, QuantLinear):
+ return
+ for attr in dir(module):
+ tmp = getattr(module, attr)
+ name1 = name + "." + attr if name != "" else attr
+ if name1 in names:
+ delattr(module, attr)
+ setattr(
+ module,
+ attr,
+ QuantLinear.new(
+ bits,
+ groupsize,
+ tmp.in_features,
+ tmp.out_features,
+ tmp.bias is not None,
+ ),
+ )
+ for name1, child in module.named_children():
+ make_quant_linear(
+ child, names, bits, groupsize, name + "." + name1 if name != "" else name1
+ )
+
+
+# TODO: perform packing on GPU
+def pack(model, quantizers, bits, groupsize):
+ layers = find_layers(model)
+ layers = {n: layers[n] for n in quantizers}
+ make_quant_linear(model, quantizers, bits, groupsize)
+ qlayers = find_layers(model, (QuantLinear,))
+ print("Packing ...")
+ for name in qlayers:
+ print(name)
+ quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
+ qlayers[name].pack(layers[name], scale, zero, g_idx)
+ print("Done.")
+ return model
+
+
+def setdeepattr(module, full_name, tensor):
+ current = module
+ tokens = full_name.split(".")
+ for token in tokens[:-1]:
+ current = getattr(current, token)
+ setattr(current, tokens[-1], tensor)
+
+
+def getdeepattr(module, full_name):
+ current = module
+ tokens = full_name.split(".")
+ for token in tokens:
+ current = getattr(current, token)
+ return current
+
+
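+# Hooks that materialize a module's parameters on the GPU right before its forward pass and move them back to CPU afterwards, so only one layer needs to fit on the device at a time.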
+def load_weights_pre_hook(module_name, weights, recursive=False):
+ def inner(module, args):
+ print(f"Pre hook {module_name}")
+ local_params = {}
+ for k, v in module.named_parameters():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+ for k, v in module.named_buffers():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+
+ for local_param in local_params:
+ current_tensor = getdeepattr(module, local_param)
+ if current_tensor.device == torch.device("meta"):
+ # print(f"Loading {local_param}")
+ if module_name:
+ tensor_name = f"{module_name}.{local_param}"
+ else:
+ tensor_name = local_param
+ tensor = weights.get_tensor(tensor_name)
+ setdeepattr(module, local_param, nn.Parameter(tensor))
+ else:
+ tensor = current_tensor.to(device=torch.device("cuda:0"))
+ if current_tensor.requires_grad:
+ tensor = nn.Parameter(tensor)
+ setdeepattr(module, local_param, tensor)
+
+ return inner
+
+
+def load_weights_post_hook(module_name, weights, recursive=False):
+ def inner(module, args, output):
+ print(f"Post hook {module_name}")
+ local_params = {}
+ for k, v in module.named_parameters():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+ for k, v in module.named_buffers():
+ if not recursive and k.count(".") != 1:
+ continue
+ local_params[k] = v
+ for local_param in local_params:
+ # print(f"Unloading {local_param}")
+ current_tensor = getdeepattr(module, local_param)
+ setdeepattr(
+ module,
+ local_param,
+ nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
+ )
+ return output
+
+ return inner
+
+
+def quantize(
+ model_id: str,
+ bits: int,
+ groupsize: int,
+ output_dir: str,
+ revision: str,
+ trust_remote_code: bool,
+ upload_to_model_id: Optional[str],
+ percdamp: float,
+ act_order: bool,
+ sym: bool,
+):
+ print("loading model")
+ config = AutoConfig.from_pretrained(
+ model_id,
+ trust_remote_code=trust_remote_code,
+ )
+
+ with init_empty_weights():
+ model = AutoModelForCausalLM.from_config(
+ config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+ )
+ model = model.eval()
+
+ print("LOADED model")
+ files = weight_files(model_id, revision, extension=".safetensors")
+ process_group, _, _ = initialize_torch_distributed()
+ weights = Weights(
+ files,
+ device=torch.device("cuda:0"),
+ dtype=torch.float16,
+ process_group=process_group,
+ aliases={"embed_tokens.weight": ["lm_head.weight"]},
+ weights_loader=DefaultWeightsLoader(UnquantizedWeight),
+ )
+ hooks = []
+ for name, module in model.named_modules():
+
+ def load(module, name):
+ def _load():
+ load_weights_pre_hook(name, weights, recursive=True)(module, None)
+
+ return _load
+
+ def unload(module, name):
+ def _unload():
+ load_weights_post_hook(name, weights, recursive=True)(
+ module, None, None
+ )
+
+ return _unload
+
+ module.load = load(module, name)
+ module.unload = unload(module, name)
+ hooks.append(
+ module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
+ )
+ hooks.append(
+ module.register_forward_hook(load_weights_post_hook(name, weights))
+ )
+ model.seqlen = 2048
+
+ dataset = "wikitext2"
+ nsamples = 128
+ seed = None
+
+ dataloader, testloader = get_loaders(
+ dataset,
+ nsamples=nsamples,
+ seed=seed,
+ model_id=model_id,
+ seqlen=model.seqlen,
+ trust_remote_code=trust_remote_code,
+ )
+
+ tick = time.time()
+ quantizers = sequential(
+ model,
+ dataloader,
+ DEV,
+ nsamples,
+ bits,
+ groupsize,
+ percdamp=percdamp,
+ act_order=act_order,
+ hooks=hooks,
+ sym=sym,
+ )
+ print(time.time() - tick)
+
+ pack(model, quantizers, bits, groupsize)
+ from safetensors.torch import save_file
+ from huggingface_hub import split_torch_state_dict_into_shards
+
+ state_dict = model.state_dict()
+ state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
+
+ max_shard_size = "10GB"
+ state_dict_split = split_torch_state_dict_into_shards(
+ state_dict,
+ filename_pattern="model.safetensors",
+ max_shard_size=max_shard_size,
+ )
+ index = None
+ if state_dict_split.is_sharded:
+ index = {
+ "metadata": state_dict_split.metadata,
+ "weight_map": state_dict_split.tensor_to_filename,
+ }
+ shards = state_dict_split.filename_to_tensors
+ os.makedirs(output_dir, exist_ok=True)
+ for shard_file, shard in shards.items():
+ save_file(
+ shard,
+ os.path.join(output_dir, shard_file),
+ metadata={
+ "format": "pt",
+ "quantized": "gptq",
+ "origin": "text-generation-inference",
+ },
+ )
+ if index is None:
+ path_to_weights = os.path.join(output_dir, "model.safetensors")
+ logger.info(f"Model weights saved in {path_to_weights}")
+ else:
+ save_index_file = "model.safetensors.index.json"
+ save_index_file = os.path.join(output_dir, save_index_file)
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+ logger.info(
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+ f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+ config.quantization_config = {
+ "bits": bits,
+ "group_size": groupsize,
+ "damp_percent": percdamp,
+ "desc_act": act_order,
+ "static_groups": False,
+ "sym": sym,
+ "quant_method": "gptq",
+ }
+ config.save_pretrained(output_dir)
+ logger.info("Saved config")
+ logger.info("Saving tokenizer")
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id, trust_remote_code=trust_remote_code
+ )
+ tokenizer.save_pretrained(output_dir)
+ logger.info("Saved tokenizer")
+
+ if upload_to_model_id:
+ api = HfApi()
+
+ api.upload_folder(
+ folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
+ )
diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/utils.py b/backends/gaudi/server/text_generation_server/layers/gptq/utils.py
new file mode 100644
index 000000000..cbc0f391f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/gptq/utils.py
@@ -0,0 +1,56 @@
+import torch
+
+
+# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
+def torch_snr_error(
+ y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
+) -> torch.Tensor:
+ """
+ Compute SNR between y_pred(tensor) and y_real(tensor)
+
+ SNR can be calcualted as following equation:
+
+ SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
+
+ if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.
+
+ SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
+
+ Args:
+ y_pred (torch.Tensor): _description_
+ y_real (torch.Tensor): _description_
+ reduction (str, optional): _description_. Defaults to 'mean'.
+
+ Raises:
+ ValueError: _description_
+ ValueError: _description_
+
+ Returns:
+ torch.Tensor: _description_
+ """
+ if y_pred.shape != y_real.shape:
+ raise ValueError(
+ f"Can not compute snr loss for tensors with different shape. "
+ f"({y_pred.shape} and {y_real.shape})"
+ )
+ reduction = str(reduction).lower()
+
+ if y_pred.ndim == 1:
+ y_pred = y_pred.unsqueeze(0)
+ y_real = y_real.unsqueeze(0)
+
+ y_pred = y_pred.flatten(start_dim=1)
+ y_real = y_real.flatten(start_dim=1)
+
+ noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
+ signal_power = torch.pow(y_real, 2).sum(dim=-1)
+ snr = (noise_power) / (signal_power + 1e-7)
+
+ if reduction == "mean":
+ return torch.mean(snr)
+ elif reduction == "sum":
+ return torch.sum(snr)
+ elif reduction == "none":
+ return snr
+ else:
+ raise ValueError("Unsupported reduction method.")
diff --git a/backends/gaudi/server/text_generation_server/layers/layernorm.py b/backends/gaudi/server/text_generation_server/layers/layernorm.py
new file mode 100644
index 000000000..848787910
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/layernorm.py
@@ -0,0 +1,67 @@
+import torch
+from torch import nn
+from accelerate import init_empty_weights
+
+
+# Monkey patching
+@classmethod
+def load_layer_norm(cls, prefix, weights, eps):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ bias = weights.get_tensor(f"{prefix}.bias")
+ with init_empty_weights():
+ ln = cls(weight.shape, eps=eps)
+
+ ln.weight = torch.nn.Parameter(weight)
+ ln.bias = torch.nn.Parameter(bias)
+ return ln
+
+
+@classmethod
+def load_layer_norm_no_bias(cls, prefix, weights, eps):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ with init_empty_weights():
+ ln = cls(weight.shape, eps=eps)
+
+ ln.weight = torch.nn.Parameter(weight)
+ ln.bias = None
+ return ln
+
+
+torch.nn.LayerNorm.load = load_layer_norm
+torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
+
+
+class FastLayerNorm(nn.LayerNorm):
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+
+ return super().forward(hidden_states), residual
+
+
+class FastRMSNorm(nn.Module):
+ def __init__(self, weight: torch.Tensor, eps: float):
+ super().__init__()
+
+ self.weight = nn.Parameter(weight)
+ self.variance_epsilon = eps
+
+ @classmethod
+ def load(cls, prefix, weights, eps=1e-6):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ return cls(weight, eps)
+
+ def forward(self, hidden_states, residual=None):
+ from vllm_hpu_extension.kernels import rms_norm
+
+ orig_shape = hidden_states.shape
+ if residual is not None:
+ residual += hidden_states.view(residual.shape)
+ else:
+ residual = hidden_states
+ # Note: HPUFusedRMSNorm requires 3D tensors as inputs
+ if len(orig_shape) == 2:
+ residual = residual.unsqueeze(0)
+ x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
+ return x.view(orig_shape), residual.view(orig_shape)
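For reference, the math that `FastRMSNorm.forward` delegates to the fused `rms_norm` HPU kernel is standard RMSNorm; the following CPU sketch is an assumption about the kernel's semantics (not taken from this diff) and is meant only to make the computation explicit:

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return weight * hidden_states * torch.rsqrt(variance + eps)

x = torch.randn(1, 4, 8)   # the HPU kernel expects 3D inputs, as noted above
w = torch.ones(8)
print(rms_norm_reference(x, w, eps=1e-6).shape)  # torch.Size([1, 4, 8])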
diff --git a/backends/gaudi/server/text_generation_server/layers/linear.py b/backends/gaudi/server/text_generation_server/layers/linear.py
new file mode 100644
index 000000000..cca80c44e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/linear.py
@@ -0,0 +1,38 @@
+import torch
+from torch.nn import functional as F
+
+
+class FastLinear(torch.nn.Module):
+ def __init__(
+ self,
+ weight,
+ bias,
+ ) -> None:
+ super().__init__()
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
+ if bias is not None:
+ self.bias = torch.nn.Parameter(bias, requires_grad=False)
+ else:
+ self.bias = None
+
+ @classmethod
+ def load(cls, config, prefix: str, weights, bias: bool):
+ weight = weights.get_tensor(f"{prefix}.weight")
+ if bias:
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+ return cls(weight, bias)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight, self.bias)
+
+
+def get_linear(weight, bias):
+ # Weights that are loaded through methods that are not
+ # quantization-aware are still bare tensors. We may want
+ # to change this in the future.
+ if isinstance(weight, torch.Tensor):
+ return FastLinear(weight, bias)
+
+ return weight.get_linear(bias)
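A small usage sketch for the helper above: `get_linear` wraps bare tensors into a `FastLinear`, whose forward is just `F.linear`. The tensors below are random placeholders:

import torch
from torch.nn import functional as F

weight = torch.randn(32, 64)   # (out_features, in_features)
bias = torch.zeros(32)
x = torch.randn(8, 64)

# Equivalent to FastLinear(weight, bias)(x), i.e. get_linear(weight, bias)(x).
out = F.linear(x, weight, bias)
print(out.shape)  # torch.Size([8, 32])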
diff --git a/backends/gaudi/server/text_generation_server/layers/lora.py b/backends/gaudi/server/text_generation_server/layers/lora.py
new file mode 100644
index 000000000..a4537b55b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/lora.py
@@ -0,0 +1,279 @@
+from typing import TYPE_CHECKING, Optional, List
+
+import torch
+import torch.distributed
+from torch import nn
+from torch.distributed import ProcessGroup
+
+from text_generation_server.utils.sgmv import (
+ add_lora_a_bgmv,
+ add_lora_b_bgmv,
+ has_sgmv,
+ lora_a_sgmv_cutlass,
+ lora_b_sgmv_cutlass,
+ orient_for_rank,
+)
+
+if TYPE_CHECKING:
+ from text_generation_server.adapters import AdapterBatchData
+ from text_generation_server.adapters.lora import BatchLoraWeights
+
+
+class LoraLinear(nn.Module):
+ def __init__(
+ self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
+ ):
+ super().__init__()
+ self.base_layer = base_layer
+ self.layer_id = layer_id
+ self.process_group = process_group
+
+ def forward_layer_type(
+ self,
+ result: torch.Tensor,
+ input: torch.Tensor,
+ adapter_data: "AdapterBatchData",
+ layer_type: str,
+ start_idx: int,
+ end_idx: int,
+ ) -> torch.Tensor:
+ if adapter_data is None:
+ return result
+ data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
+
+ if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+ # In tensor-parallel configurations, each GPU processes a specific segment of the output.
+ # The 'result' tensor represents the full output, which can vary in size based on
+ # the layer type (e.g., attention vs. feed-forward layers). We define the current
+ # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
+ # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
+ # This approach ensures accurate LoRA application across various layer sizes and
+ # configurations, adapting to different model architectures and parallelization strategies.
+ #
+ # Example scenarios where this is necessary:
+ # 1. The adapter's size doesn't evenly divide across GPUs.
+ # 2. We're processing the last segment which might be smaller.
+ # 3. Different projection layers (q, k, v) have different sizes.
+ if end_idx - start_idx != result.shape[1]:
+ proj = torch.zeros_like(result[:, start_idx:end_idx])
+ else:
+ proj = result
+
+ for r, rank_segments in data.rank_data.items():
+ lora_a_ptr = rank_segments.lora_a_ptr
+ lora_b_ptr = rank_segments.lora_b_ptr
+
+ if lora_a_ptr is None or lora_b_ptr is None:
+ raise ValueError("LoRA data is missing")
+
+ if data.use_sgmv:
+ # Use SGMV for prefill
+ v = lora_a_sgmv_cutlass(
+ input,
+ rank_segments.tmp_shrink,
+ lora_a_ptr,
+ rank_segments.segment_starts,
+ rank_segments.segment_ends,
+ self.layer_id,
+ r,
+ )
+
+ if self.process_group.size() > 1:
+ v = self.collect_lora_a(v)
+
+ lora_b_sgmv_cutlass(
+ proj,
+ v,
+ rank_segments.tmp_expand,
+ lora_b_ptr,
+ rank_segments.segment_starts,
+ rank_segments.segment_ends,
+ self.layer_id,
+ )
+ else:
+ # Use BGMV for decode
+ v = torch.zeros(
+ (input.size(0), r), dtype=input.dtype, device=input.device
+ )
+ # TODO: error with [-1, 0], but not [0, -1]
+ add_lora_a_bgmv(
+ v,
+ input,
+ lora_a_ptr,
+ rank_segments.indices,
+ self.layer_id,
+ )
+
+ if self.process_group.size() > 1:
+ v = self.collect_lora_a(v)
+
+ add_lora_b_bgmv(
+ proj,
+ v,
+ lora_b_ptr,
+ rank_segments.indices,
+ self.layer_id,
+ )
+
+ if end_idx - start_idx != result.shape[1]:
+ result[:, start_idx:end_idx] += proj
+ else:
+ for adapter_index in adapter_data.meta.adapter_set:
+ if data is not None and data.has_adapter(adapter_index):
+ adapter_mask = (
+ (adapter_data.meta.adapter_indices == adapter_index)
+ .to(input.dtype)
+ .view(-1, 1)
+ )
+ layer_result = self.forward_lora(
+ input, data, adapter_index, adapter_mask
+ )
+ result[:, start_idx:end_idx] += layer_result
+
+ return result
+
+ def forward_lora(
+ self,
+ input: torch.Tensor,
+ data: "BatchLoraWeights",
+ adapter_index: int,
+ adapter_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
+ lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
+
+ lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+ a_out = input @ lora_a
+ if self.process_group.size() > 1:
+ a_out = self.collect_lora_a(a_out)
+
+ result = (a_out @ lora_b) * adapter_mask
+ return result
+
+ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+ raise NotImplementedError("Implemented in subclasses")
+
+
+class TensorParallelMultiAdapterLinear(LoraLinear):
+ def __init__(
+ self,
+ base_layer: nn.Module,
+ layer_id: int,
+ layer_names: List[str],
+ sizes: List[int],
+ process_group: ProcessGroup,
+ ):
+ super().__init__(base_layer, layer_id, process_group)
+ self.layer_names = layer_names
+ self.sizes = sizes
+
+ @classmethod
+ def load(
+ cls,
+ base_layer: nn.Module,
+ layer_id: int,
+ layer_names: List[str],
+ sizes: List[int],
+ process_group: ProcessGroup,
+ ):
+ return TensorParallelMultiAdapterLinear(
+ base_layer, layer_id, layer_names, sizes, process_group
+ )
+
+ def forward(
+ self, input: torch.Tensor, adapter_data: "AdapterBatchData"
+ ) -> torch.Tensor:
+ result = self.base_layer(input)
+
+ # noop if no layer names are provided (e.g. for models without adapters)
+ if self.layer_names is None:
+ return result
+
+ # handle models like Bloom that have inputs of shape
+ # (batch_size, sequence_length, hidden_size)
+ # we need to reshape them to (batch_size * sequence_length, hidden_size)
+ # for the LoRA computation, then reshape back
+ prev_shape = result.shape
+ is_3d = len(input.shape) >= 3
+ if is_3d:
+ input = input.reshape(-1, input.shape[-1])
+ result = result.reshape(-1, result.shape[-1])
+
+ offset = 0
+ for i, layer_name in enumerate(self.layer_names):
+ start_idx = offset // self.process_group.size()
+ # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
+ # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
+ # ensures correct slicing of the result tensor, accommodating variations like grouped-query
+ # attention where k_proj and v_proj differ from q_proj. This allows precise application of
+ # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
+ # different projection sizes across layers and model architectures.
+ if self.sizes is not None:
+ offset += self.sizes[i]
+ end_idx = offset // self.process_group.size()
+ else:
+ end_idx = result.shape[1]
+
+ result = self.forward_layer_type(
+ result, input, adapter_data, layer_name, start_idx, end_idx
+ )
+
+ if is_3d:
+ result = result.reshape(prev_shape)
+
+ return result
+
+ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+ # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
+ # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
+ #
+ # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
+ # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+ # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+ # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+ gathered_tensors = [
+ torch.empty_like(a_out) for _ in range(self.process_group.size())
+ ]
+ torch.distributed.all_gather(gathered_tensors, a_out)
+ return torch.cat(gathered_tensors, dim=1)
+
+
+class TensorParallelAdapterRowLinear(LoraLinear):
+ def __init__(self, base_layer, layer_id, layer_name, process_group):
+ super().__init__(base_layer, layer_id, process_group)
+ self.layer_name = layer_name
+
+ @classmethod
+ def load(cls, base_layer, layer_id, layer_name, process_group):
+ return cls(base_layer, layer_id, layer_name, process_group)
+
+ def forward(
+ self, input: torch.Tensor, adapter_data: "AdapterBatchData"
+ ) -> torch.Tensor:
+ result = self.base_layer(input)
+
+ if self.layer_name is None:
+ return result
+
+ # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
+ stride = result.shape[-1] // self.process_group.size()
+ start_idx = self.process_group.rank() * stride
+ end_idx = (self.process_group.rank() + 1) * stride
+
+ self.forward_layer_type(
+ result, input, adapter_data, self.layer_name, start_idx, end_idx
+ )
+
+ return result
+
+ def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+ # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
+ # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
+ #
+ # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
+ # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+ # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+ # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+ torch.distributed.all_reduce(a_out, group=self.process_group)
+ return a_out
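To make the dense fallback path (`forward_lora`) concrete, here is a toy single-adapter sketch of the LoRA update it performs; all tensors are random stand-ins and no tensor parallelism is involved:

import torch

batch, hidden, rank, out_dim = 4, 64, 8, 64
x = torch.randn(batch, hidden)
lora_a = torch.randn(hidden, rank)
lora_b = torch.randn(rank, out_dim)
# Rows belonging to this adapter get the LoRA delta; other rows are masked out.
adapter_mask = torch.tensor([1.0, 0.0, 1.0, 0.0]).view(-1, 1)

base_out = torch.randn(batch, out_dim)        # stands in for base_layer(input)
delta = (x @ lora_a) @ lora_b * adapter_mask  # the forward_lora computation
result = base_out + delta
print(result.shape)  # torch.Size([4, 64])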
diff --git a/backends/gaudi/server/text_generation_server/layers/medusa.py b/backends/gaudi/server/text_generation_server/layers/medusa.py
new file mode 100644
index 000000000..139c4dc25
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/medusa.py
@@ -0,0 +1,191 @@
+import torch
+from torch import nn
+from typing import Tuple, Optional
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.layers.linear import FastLinear
+from text_generation_server.layers.tensor_parallel import (
+ TensorParallelHead,
+ TensorParallelColumnLinear,
+)
+
+
+class ResBlock(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ self.linear = FastLinear.load(
+ config, prefix=f"{prefix}.linear", weights=weights, bias=True
+ )
+ self.act = torch.nn.SiLU()
+
+ def forward(self, x):
+ return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+ def __init__(self, config, medusa_config, weights):
+ super().__init__()
+ self.heads = torch.nn.ModuleList(
+ [
+ MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+ for i in range(get_speculate())
+ ]
+ )
+
+ def forward(self, x):
+ if not self.heads:
+ return None
+ speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+ return speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+ def __init__(self, config, medusa_config, prefix, weights):
+ super().__init__()
+ self.blocks = torch.nn.ModuleList(
+ [
+ ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
+ for i in range(medusa_config["medusa_num_layers"])
+ ]
+ )
+ n = len(self.blocks)
+ self.out = FastLinear.load(
+ config, prefix=f"{prefix}.{n}", weights=weights, bias=False
+ )
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ x = self.out(x)
+ return x
+
+
+class MedusaHeadV1(nn.Module):
+ def __init__(self, lm_head, medusa):
+ super().__init__()
+ self.lm_head = lm_head
+ self.medusa = medusa
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ from pathlib import Path
+ from safetensors import safe_open
+ import json
+
+ speculator = config.speculator
+
+ path = speculator["path"]
+ medusa_config = str(Path(path) / "config.json")
+
+ for fname in speculator["model_paths"]:
+ filename = str(Path(path) / fname)
+
+ with open(medusa_config, "r") as f:
+ medusa_config = json.load(f)
+ routing = weights.routing
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing and routing[k] != filename:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+
+ medusa = MedusaModel(config, medusa_config, weights)
+ lm_head = TensorParallelHead.load(config, prefix, weights)
+ return MedusaHeadV1(lm_head, medusa)
+
+ def forward(
+ self, input: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ logits = self.lm_head(input)
+ # If we have too many tokens, we skip speculative logits
+ if input.shape[0] > 128:
+ return logits, None
+
+ speculative_logits = self.medusa(input)
+ return logits, speculative_logits
+
+
+class MedusaHeadV2(nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ from pathlib import Path
+ from safetensors import safe_open
+ import json
+
+ speculator_path = config.speculator["path"]
+
+ medusa_config = str(Path(speculator_path) / "config.json")
+ filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")
+
+ with open(medusa_config, "r") as f:
+ medusa_config = json.load(f)
+ routing = weights.routing
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing and routing[k] != filename:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+
+ self.n_medusa_heads = get_speculate()
+
+ assert medusa_config["medusa_num_layers"] == 1
+ self.linear = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.process_group = weights.process_group
+ self.world_size = self.process_group.size()
+ self.rank = self.process_group.rank()
+
+ self.act = torch.nn.SiLU()
+
+ self.lm_head = TensorParallelHead.load(config, prefix, weights)
+
+ def forward(self, x):
+ # If we have too many tokens, we skip speculative logits
+ if x.shape[0] > 128:
+ logits = self.lm_head(x)
+ return logits, None
+
+ size = x.shape[-1]
+ block_size = (size + self.world_size - 1) // self.world_size
+ start = self.rank * block_size
+ stop = (self.rank + 1) * block_size
+
+ x_block = x[:, start:stop]
+
+ # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
+ medusa_res = self.act(self.linear(x)).reshape(
+ *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+ )
+
+ # Apply all residual medusa heads
+ output = x[:, start:stop].unsqueeze(-2) + medusa_res
+
+ # Gather medusa heads
+ world_output = [
+ torch.empty_like(output) for _ in range(self.process_group.size())
+ ]
+ torch.distributed.all_gather(world_output, output, group=self.process_group)
+ world_output = torch.cat(world_output, dim=-1)
+
+ # Stack x and medusa residual x
+ stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
+
+ # Compute lm head on x + medusa residual x
+ logits = self.lm_head(stacked_x)
+
+ # Finally, split logits from speculative logits
+ logits, speculative_logits = torch.split(
+ logits, [1, self.n_medusa_heads], dim=-2
+ )
+ # Squeeze added dimension
+ logits = logits.squeeze(-2)
+
+ return logits, speculative_logits
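As a shape-level illustration of the Medusa head structure above (one residual SiLU block per head, i.e. `medusa_num_layers == 1`, followed by a vocabulary projection), here is a toy sketch with made-up dimensions:

import torch

hidden, vocab, n_heads = 16, 32, 3
x = torch.randn(2, hidden)

blocks = [torch.nn.Linear(hidden, hidden) for _ in range(n_heads)]
outs = [torch.nn.Linear(hidden, vocab, bias=False) for _ in range(n_heads)]

def head(i: int, h: torch.Tensor) -> torch.Tensor:
    h = h + torch.nn.functional.silu(blocks[i](h))  # ResBlock: x + SiLU(linear(x))
    return outs[i](h)                               # per-head vocabulary logits

# MedusaModel stacks the per-head logits along dim=1.
speculative_logits = torch.stack([head(i, x) for i in range(n_heads)], dim=1)
print(speculative_logits.shape)  # torch.Size([2, 3, 32])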
diff --git a/backends/gaudi/server/text_generation_server/layers/mlp.py b/backends/gaudi/server/text_generation_server/layers/mlp.py
new file mode 100644
index 000000000..d33b41f32
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/mlp.py
@@ -0,0 +1,282 @@
+import torch
+import math
+from torch import nn
+from torch.nn import functional as F
+from typing import Optional, Tuple
+from text_generation_server.layers import TensorParallelEmbedding, FastLinear
+from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.utils.speculate import get_speculate
+
+
+class MLPSpeculatorLayerNorm(nn.Module):
+ """
+ An L2 normalization implementation.
+ ...
+ Args
+ ----
+ normalized_shape : int
+ Dimensionality of input data (size of final tensor axis)
+ elementwise_scale_weight : torch.Tensor
+ learned scaling term applied after normalization
+ elementwise_shift_bias : torch.Tensor
+ learned bias term added after normalization
+ eps : float
+ Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (e.g. fp16 requires eps >= 6e-8).
+ """
+
+ def __init__(
+ self,
+ prefix,
+ config,
+ weights,
+ eps=1e-06,
+ ):
+ super(MLPSpeculatorLayerNorm, self).__init__()
+ self.weight = weights.get_tensor(f"{prefix}.weight")
+ self.bias = weights.get_tensor(f"{prefix}.bias")
+ self.eps = eps
+
+ def forward(self, x):
+ xf = x
+ xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
+ x = xf.type_as(x)
+ x = self.weight * x
+ x = x + self.bias
+ return x
+
+
+INV_SQRT2 = 2**-0.5
+
+
+def simple_norm(x: torch.Tensor, eps=1e-06):
+ xf = x
+ xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+ x = xf.type_as(x)
+ return x * INV_SQRT2
+
+
+class MLPSpeculatorModelTied(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ self.config = config
+ self.n_predict = get_speculate()
+ self.hidden_size = config.hidden_size
+
+ self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
+ self.proj0 = FastLinear.load(
+ config,
+ prefix=f"{prefix}.proj.0",
+ weights=weights,
+ bias=False,
+ )
+ self.proj1 = FastLinear.load(
+ config,
+ prefix=f"{prefix}.proj.1",
+ weights=weights,
+ bias=False,
+ )
+ self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
+ self.ln = MLPSpeculatorLayerNorm(
+ prefix=f"{prefix}.ln.0",
+ config=config,
+ weights=weights,
+ )
+
+ # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+ self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+ self.activation = nn.GELU()
+ self.vsize = config.vocab_size
+ self.inner_dim = config.speculator_config["inner_dim"]
+ self.top_k_tokens_per_head = [1] * self.n_predict
+ self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+ self.inner_dim / 2
+ )
+ self.emb.weight *= self.emb_weight
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ ):
+ top_k_tokens_per_head = self.top_k_tokens_per_head
+
+ # k indicates # of candidates
+ # h indicates # of generated tokens
+ state = hidden_states
+ b = state.size(0)
+ ind = input_ids.unsqueeze(0)
+ all_probs = torch.empty(
+ b, self.n_predict, self.vsize, device=state.device
+ ) # b k h v
+ assert (
+ len(top_k_tokens_per_head) == self.n_predict
+ ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+ for i in range(self.n_predict):
+ # Project and predict
+ z = self.emb(ind)
+ # z = z.mul(self.emb_weight) # b k d
+ if i == 0:
+ state = self.proj0(state) * self.state_weight + z
+ else:
+ state = self.proj1(state) * self.state_weight + z
+ state = self.activation(self.ln(state)) # b k d
+ probs = F.log_softmax(self.head(state), dim=-1) # b k v
+ _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
+
+ # Update candidate set with new predictions
+
+ # Update distribution set with new logits
+ all_probs[:, i] = probs.exp()
+
+ # Update state, log_probs and ind for new predictions
+ state = state.unsqueeze(2).expand(
+ -1, -1, top_k_tokens_per_head[i], -1
+ ) # b k k' d
+ state = state.reshape(-1, b, state.size(3)) # b kk' d
+ ind = preds.view(-1, b) # b kk'
+
+ speculative_logits = all_probs
+ return speculative_logits
+
+
+class MLPSpeculatorModel(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ self.config = config
+ self.n_predict = get_speculate()
+ self.hidden_size = config.hidden_size
+
+ self.emb = nn.ModuleList(
+ [
+ TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+ for i in range(self.n_predict)
+ ]
+ )
+ self.proj = [
+ FastLinear.load(
+ config,
+ prefix=f"{prefix}.proj.{i}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_predict)
+ ]
+ self.head = nn.ModuleList(
+ [
+ FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
+ for i in range(self.n_predict)
+ ]
+ )
+ self.ln = nn.ModuleList(
+ [
+ MLPSpeculatorLayerNorm(
+ prefix=f"{prefix}.ln.{i}",
+ config=config,
+ weights=weights,
+ )
+ for i in range(self.n_predict)
+ ]
+ )
+
+ # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+ self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+ self.activation = nn.GELU()
+ self.vsize = config.vocab_size
+ self.inner_dim = config.speculator_config["inner_dim"]
+ self.top_k_tokens_per_head = [1] * self.n_predict
+ self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+ self.inner_dim / 2
+ )
+ self.emb.weight *= self.emb_weight
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_ids: torch.Tensor,
+ ):
+ top_k_tokens_per_head = self.top_k_tokens_per_head
+
+ # k indicates # of candidates
+ # h indicates # of generated tokens
+ state = hidden_states
+ b = state.size(0)
+ ind = input_ids.unsqueeze(0)
+ all_probs = torch.empty(
+ b, self.n_predict, self.vsize, device=state.device
+ ) # b k h v
+ assert (
+ len(top_k_tokens_per_head) == self.n_predict
+ ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+ for i in range(self.n_predict):
+ # Project and predict
+ z = self.emb[i](ind)
+ # z = z.mul(self.emb_weight) # b k d
+ state = self.proj[i](state) * self.state_weight + z
+ state = self.activation(self.ln[i](state)) # b k d
+ probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
+ _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
+
+ # Update candidate set with new predictions
+
+ # Update distribution set with new logits
+ all_probs[:, i] = probs.exp()
+
+ # Update state, log_probs and ind for new predictions
+ state = state.unsqueeze(2).expand(
+ -1, -1, top_k_tokens_per_head[i], -1
+ ) # b k k' d
+ state = state.reshape(-1, b, state.size(3)) # b kk' d
+ ind = preds.view(-1, b) # b kk'
+
+ speculative_logits = all_probs
+ return speculative_logits
+
+
+class MLPSpeculatorHead(nn.Module):
+ def __init__(self, lm_head, mlp_speculator, scale_input: bool):
+ super().__init__()
+ self.lm_head = lm_head
+ self.mlp_speculator = mlp_speculator
+ self.scale_input = scale_input
+
+ def forward(
+ self, input: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ logits = self.lm_head(input)
+ # If we have too many tokens, we skip speculative logits
+ if input.shape[0] > 128:
+ return logits, None
+
+ input_ids = logits.argmax(dim=-1)
+ if self.scale_input:
+ input = simple_norm(input)
+ speculative_logits = self.mlp_speculator(input, input_ids)
+ return logits, speculative_logits
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ from pathlib import Path
+ from safetensors import safe_open
+
+ speculator_path = config.speculator["path"]
+
+ for fname in config.speculator["model_paths"]:
+ filename = str(Path(speculator_path) / fname)
+ routing = weights.routing
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing and routing[k] != filename:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+
+ tie_weights = config.speculator_config.get("tie_weights", False)
+ if tie_weights:
+ mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
+ else:
+ mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+ # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+ scale_input = config.speculator_config.get("scale_input", False)
+ lm_head = TensorParallelHead.load(config, prefix, weights)
+ return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
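A quick numeric check of the `state_weight` comment above: with `state_weight = 0.5 ** (0.5 / n_predict)`, the original state's variance share after `n_predict` scaled updates is `state_weight ** (2 * n_predict) = 0.5`, i.e. 50% of the final state magnitude.

n_predict = 4
state_weight = 0.5 ** (0.5 / n_predict)

# Variance share of the original state after n_predict scaled updates:
print(state_weight ** (2 * n_predict))  # ~0.5, i.e. 50% as the comment says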
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/__init__.py b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
new file mode 100644
index 000000000..8b9d6fcb0
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/__init__.py
@@ -0,0 +1,250 @@
+from typing import Optional, Protocol, runtime_checkable
+
+import torch
+import torch.nn as nn
+from loguru import logger
+from transformers.activations import ACT2FN
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import (
+ DefaultWeightsLoader,
+ Weights,
+ UnquantizedWeight,
+)
+
+from .fused_moe import fused_topk, grouped_topk
+
+ # NOTE: we are using a protocol here, because multiple inheritance is not nice.
+ # We need `Module`, and a `Module` -> some abstract class -> some concrete
+ # class inheritance chain gets awkward.
+
+
+@runtime_checkable
+class MoELayer(Protocol):
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ hidden_act: str = "silu",
+ scoring_func: Optional[str] = None,
+ e_score_correction_bias: Optional[float] = None,
+ ): ...
+
+ def forward(
+ self, x: torch.Tensor, *, gating_output: torch.Tensor
+ ) -> torch.Tensor: ...
+
+
+class DenseMoELayer(nn.Module):
+ """
+ Layer for MoE that applies *all* experts to each token and then weights
+ their outputs based on the calculated routing. This layer is much slower
+ than `SparseMoELayer` and should only be used when no fused kernels are
+ available (e.g. for unsupported quantizers).
+ """
+
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ hidden_act: str = "silu",
+ scoring_func: Optional[str] = None,
+ e_score_correction_bias: Optional[float] = None,
+ ):
+ super().__init__()
+
+ assert scoring_func is None, "scoring func is not handled"
+ assert e_score_correction_bias is None, "scoring correction bias is not handled"
+
+ log_once(
+ logger.info,
+ "No fused layers are available for this model type, using (slower) dense MoE layer",
+ )
+
+ assert (n_expert_group is None) == (
+ topk_group is None
+ ), "n_expert_group and topk_group must both be None or have some value"
+
+ self.n_expert_group = n_expert_group
+ self.n_experts = n_experts
+ self.renormalize = renormalize
+ self.topk = topk
+ self.topk_group = topk_group
+
+ if "gelu" in hidden_act:
+ self.act = lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh"
+ if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none"
+ ),
+ )
+ elif "silu" in hidden_act:
+ self.act = torch.nn.functional.silu
+ else:
+ self.act = ACT2FN[hidden_act]
+
+ self.gate_proj = [
+ TensorParallelColumnLinear.load(
+ None,
+ prefix=f"{prefix}.{i}.{gate_proj_name}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_experts)
+ ]
+ self.up_proj = [
+ TensorParallelColumnLinear.load(
+ None,
+ prefix=f"{prefix}.{i}.{up_proj_name}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_experts)
+ ]
+ self.down_proj = [
+ TensorParallelRowLinear.load(
+ None,
+ prefix=f"{prefix}.{i}.{down_proj_name}",
+ weights=weights,
+ bias=False,
+ )
+ for i in range(self.n_experts)
+ ]
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ """
+ x: (sequence_length, model_dim)
+ gating_output: (sequence_length, n_experts)
+ """
+ # optional reshape
+ input_shape = x.shape
+ x = x.view(-1, input_shape[-1])
+
+ if self.n_expert_group is not None and self.topk_group is not None:
+ topk_weights, topk_ids = grouped_topk(
+ x,
+ gating_output,
+ self.topk,
+ renormalize=self.renormalize,
+ num_expert_group=self.n_expert_group,
+ topk_group=self.topk_group,
+ )
+ else:
+ topk_weights, topk_ids = fused_topk(
+ x, gating_output, self.topk, self.renormalize
+ )
+ topk_weights = topk_weights.to(x.dtype)
+
+ weights = torch.zeros(
+ topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
+ )
+
+ weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
+
+ out = torch.zeros_like(x)
+ for i in range(self.n_experts):
+ h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
+ h = self.down_proj[i](h, reduce=False)
+ out += h * weights[:, i].view(-1, 1)
+
+ return out
+
+
+class SparseMoELayer(nn.Module):
+ """
+ Layer for MoE that uses fused kernels to only apply the active experts
+ for each token (rather than applying all experts and selecting the
+ outputs of active experts).
+ """
+
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ scoring_func: Optional[str] = "softmax",
+ e_score_correction_bias: Optional[float] = None,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ ):
+ super().__init__()
+ if (
+ isinstance(weights.loader, DefaultWeightsLoader)
+ and isinstance(weights.loader.weight_class, UnquantizedWeight)
+ ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+ if (
+ isinstance(weights.loader, HybridFP8UnquantLoader)
+ and weights.loader.to_fp8
+ ):
+ cls = FP8SparseMoELayer
+ else:
+ cls = UnquantizedSparseMoELayer
+ else:
+ raise ValueError(
+ f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
+ )
+
+ log_once(
+ logger.info,
+ "Using MoE layer wih fused gemm",
+ )
+
+ self.moe = cls(
+ n_expert_group=n_expert_group,
+ n_experts=n_experts,
+ prefix=prefix,
+ renormalize=renormalize,
+ topk=topk,
+ topk_group=topk_group,
+ weights=weights,
+ scoring_func=scoring_func,
+ e_score_correction_bias=e_score_correction_bias,
+ gate_proj_name=gate_proj_name,
+ up_proj_name=up_proj_name,
+ down_proj_name=down_proj_name,
+ )
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ return self.moe(x, gating_output=gating_output)
+
+ @staticmethod
+ def is_supported(weights: Weights) -> bool:
+ return (
+ isinstance(weights.loader, DefaultWeightsLoader)
+ and isinstance(weights.loader.weight_class, UnquantizedWeight)
+ ) or isinstance(weights.loader, HybridFP8UnquantLoader)
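To illustrate the mixing step in `DenseMoELayer.forward` above, here is a toy sketch: given dense per-expert routing weights (zero for unselected experts, e.g. produced by scatter-ing the top-k weights), every expert is applied to every token and the outputs are combined by those weights. The expert weights below are plain random matrices, not the real gate/up/down projections:

import torch
from torch.nn import functional as F

tokens, n_experts, dim = 4, 3, 16
x = torch.randn(tokens, dim)
# Dense routing weights per token (rows sum to 1, zeros for unselected experts).
weights = torch.tensor([[0.7, 0.3, 0.0],
                        [0.0, 0.5, 0.5],
                        [1.0, 0.0, 0.0],
                        [0.2, 0.0, 0.8]])

# Stand-in expert weight matrices of shape (out_features, in_features).
expert_w = [torch.randn(dim, dim) for _ in range(n_experts)]
out = torch.zeros_like(x)
for i in range(n_experts):
    out += F.linear(x, expert_w[i]) * weights[:, i].view(-1, 1)
print(out.shape)  # torch.Size([4, 16])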
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
new file mode 100644
index 000000000..071b2abee
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
@@ -0,0 +1,173 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.weights import Weights
+from text_generation_server.layers.fp8 import (
+ Fp8Weight,
+ fp8_quantize,
+ quant_dtype,
+ normalize_e4m3fn_to_native_float8,
+)
+
+try:
+ from .unquantized import fused_moe
+except Exception:
+ fused_moe = None
+
+
+class FP8SparseMoELayer(nn.Module):
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ scoring_func: Optional[str] = "softmax",
+ e_score_correction_bias: Optional[float] = None,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ ):
+ super().__init__()
+
+ assert (n_expert_group is None) == (
+ topk_group is None
+ ), "n_expert_group and topk_group must both be None or have some value"
+
+ self.n_expert_group = n_expert_group
+ self.topk = topk
+ self.topk_group = topk_group
+ self.renormalize = renormalize
+ self.weight_block_size = weights.weights_loader.weight_block_size
+ self.scoring_func = scoring_func
+ self.e_score_correction_bias = e_score_correction_bias
+
+ (
+ self.gate_up_proj,
+ self.gate_up_proj_weight_scale,
+ self.gate_up_proj_input_scale,
+ ) = _load_expert_multi_weights_col(
+ prefix=prefix,
+ n_experts=n_experts,
+ gate_proj_name=gate_proj_name,
+ up_proj_name=up_proj_name,
+ weights=weights,
+ )
+
+ self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
+ _load_expert_weights_row(
+ prefix=prefix,
+ n_experts=n_experts,
+ name=down_proj_name,
+ weights=weights,
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ return fused_moe(
+ x,
+ w1=self.gate_up_proj,
+ w2=self.down_proj,
+ gating_output=gating_output,
+ topk=self.topk,
+ renormalize=self.renormalize,
+ inplace=True,
+ use_grouped_topk=self.n_expert_group is not None,
+ num_expert_group=self.n_expert_group,
+ topk_group=self.topk_group,
+ scoring_func=self.scoring_func,
+ e_score_correction_bias=self.e_score_correction_bias,
+ use_fp8_w8a8=True,
+ w1_scale=self.gate_up_proj_weight_scale,
+ w2_scale=self.down_proj_weight_scale,
+ a1_scale=self.gate_up_proj_input_scale,
+ a2_scale=self.down_proj_input_scale,
+ )
+
+
+def _load_expert_weights(
+ get_weight_fn,
+ *,
+ prefix: str,
+ n_experts: int,
+ name: str,
+ weights: Weights,
+) -> torch.Tensor:
+ all_weight = None
+ all_weight_scales = None
+ max_input_scale = None
+
+ for i in range(n_experts):
+ weight = get_weight_fn(prefix, i, name, weights)
+
+ assert isinstance(weight, Fp8Weight)
+
+ if all_weight is None:
+ all_weight = torch.empty(
+ (n_experts,) + weight.weight.shape,
+ dtype=quant_dtype,
+ device=weight.weight.device,
+ )
+ if all_weight_scales is None:
+ all_weight_scales = torch.empty(
+ (n_experts,) + weight.weight_scale.shape,
+ dtype=torch.float32,
+ device=weight.weight.device,
+ )
+
+ if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
+ all_weight[i], all_weight_scales[i], current_input_scale = (
+ normalize_e4m3fn_to_native_float8(
+ weight.weight, weight.weight_scale, weight.input_scale
+ )
+ )
+ if current_input_scale is not None:
+ if max_input_scale is None or current_input_scale > max_input_scale:
+ max_input_scale = current_input_scale
+ else:
+ all_weight[i], all_weight_scales[i] = fp8_quantize(
+ weight.weight, scalar=True
+ )
+
+ assert all_weight is not None
+
+ return all_weight, all_weight_scales, max_input_scale
+
+
+def _load_expert_multi_weights_col(
+ *,
+ prefix: str,
+ n_experts: int,
+ gate_proj_name: str,
+ up_proj_name: str,
+ weights: Weights,
+) -> torch.Tensor:
+ def get_weight_fn(prefix, i, name, weights):
+ return weights.get_multi_weights_col(
+ [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+ )
+
+ return _load_expert_weights(
+ get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
+ )
+
+
+def _load_expert_weights_row(
+ *,
+ prefix: str,
+ n_experts: int,
+ name: str,
+ weights: Weights,
+) -> torch.Tensor:
+ def get_weight_fn(prefix, i, name, weights):
+ return weights.get_weights_row(f"{prefix}.{i}.{name}")
+
+ return _load_expert_weights(
+ get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
+ )
diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
similarity index 80%
rename from server/text_generation_server/layers/moe/fused_moe_rocm.py
rename to backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
index 68accb990..e26ff8770 100644
--- a/server/text_generation_server/layers/moe/fused_moe_rocm.py
+++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
@@ -16,10 +16,8 @@
from typing import Tuple
import torch
-import torch.distributed
-# TODO: Remove the functions once moe_kernel are built for ROCM
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -50,3 +48,18 @@ def grouped_topk(
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
+
+
+def fused_topk(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ topk_weights = torch.nn.functional.softmax(
+ gating_output, dim=1, dtype=torch.float32
+ )
+ topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+ if renormalize:
+ topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+ return topk_weights, topk_ids
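A short usage sketch for the `fused_topk` routing helper added above, assuming the gaudi backend's `text_generation_server` package is importable (otherwise the function body is only a softmax plus `torch.topk`):

import torch
from text_generation_server.layers.moe.fused_moe import fused_topk

hidden_states = torch.randn(3, 16)   # not used by the routing itself
gating_output = torch.randn(3, 8)    # (tokens, n_experts)

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
print(topk_ids.shape)                # torch.Size([3, 2])
print(topk_weights.sum(dim=-1))      # each row sums to 1.0 after renormalization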
diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
new file mode 100644
index 000000000..ec1583989
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
@@ -0,0 +1,121 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from vllm_hpu_extension.ops import DynamicFusedMOE
+
+
+class UnquantizedSparseMoELayer(nn.Module):
+ def __init__(
+ self,
+ *,
+ n_expert_group: Optional[int],
+ n_experts: int,
+ prefix: str,
+ renormalize: bool,
+ topk: int,
+ topk_group: Optional[int],
+ weights: Weights,
+ scoring_func: Optional[str] = "softmax",
+ e_score_correction_bias: Optional[float] = None,
+ gate_proj_name: str = "gate_proj",
+ up_proj_name: str = "up_proj",
+ down_proj_name: str = "down_proj",
+ ):
+ super().__init__()
+
+ assert (n_expert_group is None) == (
+ topk_group is None
+ ), "n_expert_group and topk_group must both be None or have some value"
+
+ self.n_expert_group = n_expert_group
+ self.topk = topk
+ self.topk_group = topk_group
+ self.renormalize = renormalize
+ self.weight_block_size = weights.weights_loader.weight_block_size
+ self.scoring_func = scoring_func
+ self.e_score_correction_bias = e_score_correction_bias
+
+ self.gate_up_proj = _load_expert_multi_weights_col(
+ prefix=prefix,
+ n_experts=n_experts,
+ gate_proj_name=gate_proj_name,
+ up_proj_name=up_proj_name,
+ weights=weights,
+ )
+
+ self.down_proj = _load_expert_weights_row(
+ prefix=prefix,
+ n_experts=n_experts,
+ name=down_proj_name,
+ weights=weights,
+ )
+
+ self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+ for i in range(n_experts):
+ self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+ self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+
+ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+ return self.hpu_fused_moe(x, gating_output, self.topk)
+
+
+def _load_expert_multi_weights_col(
+ *,
+ prefix: str,
+ n_experts: int,
+ gate_proj_name: str,
+ up_proj_name: str,
+ weights: Weights,
+) -> torch.Tensor:
+ all_weight = None
+ for i in range(n_experts):
+ weight = weights.get_multi_weights_col(
+ [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+ )
+
+ assert isinstance(weight, UnquantizedWeight)
+
+ if all_weight is None:
+ all_weight = torch.empty(
+ (n_experts,) + weight.weight.shape,
+ dtype=weight.weight.dtype,
+ device=weight.weight.device,
+ )
+
+ all_weight[i] = weight.weight
+
+ assert all_weight is not None
+
+ return all_weight
+
+
+def _load_expert_weights_row(
+ *,
+ prefix: str,
+ n_experts: int,
+ name: str,
+ weights: Weights,
+) -> torch.Tensor:
+ all_weight = None
+ for i in range(n_experts):
+ weight = weights.get_weights_row(
+ f"{prefix}.{i}.{name}",
+ )
+
+ assert isinstance(weight, UnquantizedWeight)
+
+ if all_weight is None:
+ all_weight = torch.empty(
+ (n_experts,) + weight.weight.shape,
+ dtype=weight.weight.dtype,
+ device=weight.weight.device,
+ )
+
+ all_weight[i] = weight.weight
+
+ assert all_weight is not None
+
+ return all_weight
diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py
new file mode 100644
index 000000000..6a83d6a57
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/rotary.py
@@ -0,0 +1,606 @@
+import os
+import math
+import torch
+from torch import nn
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+
+
+def _create_inv_freq(dim, base, device):
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+ )
+ return inv_freq
+
+
+def _get_rope_config(config):
+ if os.getenv("ROPE_SCALING", None) is not None:
+ rope_scaling = {
+ "type": os.environ["ROPE_SCALING"],
+ "factor": float(os.environ["ROPE_FACTOR"]),
+ }
+ return rope_scaling
+ return getattr(config, "rope_scaling", None)
+
+
+class PositionRotaryEmbedding(nn.Module):
+ def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
+ super().__init__()
+ self.inv_freq = inv_freq
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+ self.scaling_factor = scaling_factor
+ self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, inv_freq.device, max_position_embeddings
+ )
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ):
+ num_tokens = query.shape[0]
+ head_size = query.shape[-1]
+ # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+ # to query hidden dimension, so the original tensors need to be
+ # expanded
+ # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+ # and expansion of cos/sin tensors via concatenation
+ rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+ cos = torch.cat((cos, cos), dim=-1)
+ sin = torch.cat((sin, sin), dim=-1)
+ rotary_dim = cos.shape[-1]
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, head_size)
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, head_size)
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
+
+ @classmethod
+ def static(cls, config, dim, base, device):
+ inv_freq = _create_inv_freq(dim, base, device)
+ scaling_factor = None
+ rope_scaling = _get_rope_config(config)
+ if not hasattr(config, "max_position_embeddings") and hasattr(
+ config, "max_seq_len"
+ ):
+ # handling for dbrx
+ config.max_position_embeddings = config.max_seq_len
+ if rope_scaling is not None:
+ # `rope_type` is now standard in transformers, but some existing models
+ # have `type` instead.
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+
+ if rope_type == "linear":
+ pass
+ elif rope_type == "default":
+ pass
+ elif rope_type == "mrope":
+ mrope_section = rope_scaling["mrope_section"]
+ if mrope_section is not None:
+ return RotaryPositionEmbeddingMultimodalSections(
+ inv_freq,
+ scaling_factor,
+ mrope_section,
+ config.max_position_embeddings,
+ )
+ elif rope_type == "dynamic":
+ scaling_factor = rope_scaling["factor"]
+ return DynamicPositionRotaryEmbedding(
+ dim=dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=base,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ )
+ elif rope_type == "llama3":
+ inv_freq = apply_llama3_scaling(
+ inv_freq,
+ scaling_factor=rope_scaling["factor"],
+ low_freq_factor=rope_scaling["low_freq_factor"],
+ high_freq_factor=rope_scaling["high_freq_factor"],
+ original_max_position_embeddings=rope_scaling[
+ "original_max_position_embeddings"
+ ],
+ )
+
+ return cls(inv_freq, scaling_factor, config.max_position_embeddings)
+
+ elif rope_type == "yarn":
+ scaling_factor = rope_scaling["factor"]
+ mscale = rope_scaling.get("mscale", 1.0)
+ mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
+ return YarnPositionRotaryEmbedding(
+ dim=2 * inv_freq.shape[0],
+ max_position_embeddings=rope_scaling[
+ "original_max_position_embeddings"
+ ],
+ base=base,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ extrapolation_factor=1,
+ attn_factor=1,
+ beta_fast=32,
+ beta_slow=1,
+ mscale=mscale,
+ mscale_all_dim=mscale_all_dim,
+ )
+ elif rope_type in ["su", "longrope"]:
+ short_factor = torch.tensor(
+ rope_scaling["short_factor"], dtype=torch.float32, device=device
+ )
+ short_inv_freq = 1.0 / (
+ short_factor
+ * base
+ ** (
+ torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+ / dim
+ )
+ )
+ long_factor = torch.tensor(
+ rope_scaling["long_factor"], dtype=torch.float32, device=device
+ )
+ long_inv_freq = 1.0 / (
+ long_factor
+ * base
+ ** (
+ torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+ / dim
+ )
+ )
+
+ original_max_position_embeddings = (
+ config.original_max_position_embeddings
+ )
+ max_position_embeddings = config.max_position_embeddings
+ if max_position_embeddings <= original_max_position_embeddings:
+ scaling_factor = 1.0
+ else:
+ scale = max_position_embeddings / original_max_position_embeddings
+ scaling_factor = math.sqrt(
+ 1 + math.log(scale) / math.log(original_max_position_embeddings)
+ )
+
+ # if short_mscale and long_mscale are provided we need to scale the freqs
+ # using the Phi3LongRoPEScaledRotaryEmbedding
+ if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
+ short_mscale = rope_scaling["short_mscale"]
+ long_mscale = rope_scaling["long_mscale"]
+ return Phi3LongRoPEScaledRotaryEmbedding(
+ short_inv_freq=short_inv_freq,
+ long_inv_freq=long_inv_freq,
+ max_position_embeddings=config.max_position_embeddings,
+ short_mscale=short_mscale,
+ long_mscale=long_mscale,
+ original_max_position_embeddings=original_max_position_embeddings,
+ )
+
+ return SuRotaryEmbedding(
+ short_inv_freq=short_inv_freq,
+ long_inv_freq=long_inv_freq,
+ scaling_factor=scaling_factor,
+ original_max_position_embeddings=original_max_position_embeddings,
+ max_position_embeddings=config.max_position_embeddings,
+ )
+ else:
+ raise NotImplementedError(
+ f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+ )
+ return cls(inv_freq, scaling_factor, config.max_position_embeddings)
+
+ @classmethod
+ def load(cls, config, prefix, weights):
+ # XXX: Always load this in float32 !
+ dtype = weights.dtype
+ weights.dtype = torch.float32
+ inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
+ weights.dtype = dtype
+
+ scaling_factor = None
+ rope_scaling = _get_rope_config(config)
+ if rope_scaling is not None:
+ scaling_factor = rope_scaling["factor"]
+ if rope_scaling["type"] == "linear":
+ pass
+ elif rope_scaling["type"] == "dynamic":
+ return DynamicPositionRotaryEmbedding(
+ dim=2 * inv_freq.shape[0],
+ max_position_embeddings=config.max_position_embeddings,
+ base=10000.0,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ )
+ elif rope_scaling["type"] == "yarn":
+ mscale = rope_scaling.get("mscale", 1.0)
+ mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
+ return YarnPositionRotaryEmbedding(
+ dim=2 * inv_freq.shape[0],
+ max_position_embeddings=rope_scaling[
+ "original_max_position_embeddings"
+ ],
+ base=10000.0,
+ device=inv_freq.device,
+ scaling_factor=scaling_factor,
+ extrapolation_factor=1,
+ attn_factor=1,
+ beta_fast=32,
+ beta_slow=1,
+ mscale=mscale,
+ mscale_all_dim=mscale_all_dim,
+ )
+ else:
+ raise NotImplementedError(
+ f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+ )
+ return cls(inv_freq, scaling_factor, config.max_position_embeddings)
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ if self.scaling_factor is not None:
+ t /= self.scaling_factor
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
+ def get_cos_sin(self, position_ids: torch.Tensor):
+
+ cos = torch.index_select(self._cos_cached, 0, position_ids)
+ sin = torch.index_select(self._sin_cached, 0, position_ids)
+
+ # Note: this unsqueeze is not necessary with the ROCm + vLLM RoPE implementation, but we leave it as is to avoid yet another control flow branch.
+ return cos.unsqueeze(1), sin.unsqueeze(1)
+
+
+class SuRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ short_inv_freq,
+ long_inv_freq,
+ scaling_factor,
+ original_max_position_embeddings,
+ max_position_embeddings,
+ ):
+ super(PositionRotaryEmbedding, self).__init__()
+ self.short_inv_freq = short_inv_freq
+ self.long_inv_freq = long_inv_freq
+ self.scaling_factor = scaling_factor
+ self.original_max_position_embeddings = original_max_position_embeddings
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+ self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, short_inv_freq.device, max_position_embeddings
+ )
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached is None
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+
+ t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+ short_freqs = torch.outer(
+ t[: self.original_max_position_embeddings],
+ self.short_inv_freq.to(device=t.device),
+ )
+ long_freqs = torch.outer(
+ t[self.original_max_position_embeddings :],
+ self.long_inv_freq.to(device=t.device),
+ )
+
+ freqs = torch.cat([short_freqs, long_freqs])
+
+ self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+ self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
+
+
+class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ short_inv_freq: torch.Tensor,
+ long_inv_freq: torch.Tensor,
+ max_position_embeddings: int,
+ short_mscale: float,
+ long_mscale: float,
+ original_max_position_embeddings: int,
+ ):
+ super(PositionRotaryEmbedding, self).__init__()
+ self.short_inv_freq = short_inv_freq
+ self.long_inv_freq = long_inv_freq
+ self.max_position_embeddings = max_position_embeddings
+ self.short_mscale = short_mscale
+ self.long_mscale = long_mscale
+ self.original_max_position_embeddings = original_max_position_embeddings
+
+ # cache
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+ self.dynamic_args = None
+ self._update_cos_sin_cache(
+ torch.float32, short_inv_freq.device, max_position_embeddings
+ )
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached is None
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+
+ short_freqs = torch.outer(
+ t[: self.original_max_position_embeddings],
+ self.short_inv_freq.to(device=t.device),
+ )
+
+ long_freqs = torch.outer(
+ t[self.original_max_position_embeddings :],
+ self.long_inv_freq.to(device=t.device),
+ )
+
+ short_freqs = short_freqs * self.short_mscale
+ long_freqs = long_freqs * self.long_mscale
+
+ freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)
+ freqs[: self.original_max_position_embeddings] = short_freqs
+ freqs[self.original_max_position_embeddings :] = long_freqs
+
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
+
+class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+ inv_freq = _create_inv_freq(dim, base, device)
+ super().__init__(inv_freq, scaling_factor, max_position_embeddings)
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ if seqlen > self.max_position_embeddings:
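+                # Dynamic NTK-style scaling: enlarge the rotary base as the
+                # sequence grows past max_position_embeddings so that the
+                # wavelengths stretch instead of positions being interpolated.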
+ newbase = self.base * (
+ (self.scaling_factor * seqlen / self.max_position_embeddings)
+ - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ self.inv_freq = _create_inv_freq(
+ self.dim, newbase, self.inv_freq.device
+ )
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+
+
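+# Inverts the rotation count: rotary dimension d has frequency base^(-2d/dim), so over
+# max_position_embeddings positions it completes max_position_embeddings * base^(-2d/dim) / (2*pi)
+# full rotations; solving that for d at a given rotation count yields the expression below.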
+def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+ 2 * math.log(base)
+ )
+
+
+# Find dim range bounds based on rotations
+def find_correction_range(
+ low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
+
+
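+# Linear 0 -> 1 ramp over [min, max], clamped outside the interval; for example
+# linear_ramp_mask(2, 6, 8) -> tensor([0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]).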
+def linear_ramp_mask(min, max, dim):
+ if min == max:
+ max += 0.001 # Prevent singularity
+
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+ ramp_func = torch.clamp(linear_func, 0, 1)
+ return ramp_func
+
+
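+# Attention magnitude correction used by YaRN: any scale <= 1 returns 1.0, and for
+# example get_mscale(scale=4.0) returns 0.1 * ln(4) + 1, approximately 1.139.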
+def get_mscale(scale: float = 1.0, mscale: float = 1.0):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ dim,
+ max_position_embeddings,
+ base,
+ device,
+ scaling_factor,
+ *,
+ extrapolation_factor,
+ attn_factor,
+ beta_fast,
+ beta_slow,
+ mscale: float,
+ mscale_all_dim: float,
+ ):
+ inv_freq = _create_inv_freq(dim, base, device)
+ super().__init__(
+            inv_freq, scaling_factor, max_position_embeddings * scaling_factor
+ )
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.extrapolation_factor = extrapolation_factor
+ self.attn_factor = attn_factor
+ self.beta_fast = beta_fast
+ self.beta_slow = beta_slow
+ self.mscale_all_dim = mscale_all_dim
+ self.scaling_factor = scaling_factor
+ self.mscale = float(
+ get_mscale(self.scaling_factor, mscale)
+ / get_mscale(self.scaling_factor, mscale_all_dim)
+ * self.attn_factor
+ ) # Get n-d magnitude scaling corrected for interpolation
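+        # This factor is folded into the cos/sin caches in _update_cos_sin_cache
+        # below, so it directly scales the rotated queries and keys.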
+
+ def _update_cos_sin_cache(self, dtype, device, seqlen):
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ if seqlen > self.max_position_embeddings or True:
+ inv_freq_extrapolation = _create_inv_freq(
+ self.dim, self.base, self.inv_freq.device
+ )
+ freqs = 1.0 / inv_freq_extrapolation
+ inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
+ low, high = find_correction_range(
+ self.beta_fast,
+ self.beta_slow,
+ self.dim,
+ self.base,
+ self.max_position_embeddings,
+ )
+
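+                # Blend the two tables: dimensions below `low` (high frequency)
+                # keep the original extrapolated inv_freq, dimensions above
+                # `high` use the interpolated inv_freq (divided by the scaling
+                # factor), with a linear ramp in between.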
+ inv_freq_mask = (
+ 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
+ ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
+ inv_freq = (
+ inv_freq_interpolation * (1 - inv_freq_mask)
+ + inv_freq_extrapolation * inv_freq_mask
+ )
+
+ self.inv_freq = inv_freq
+
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ # Don't do einsum, it converts fp32 to fp16
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
+ self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
+
+def apply_llama3_scaling(
+ freqs: torch.Tensor,
+ *,
+ scaling_factor: int,
+ low_freq_factor: int,
+ high_freq_factor: int,
+ original_max_position_embeddings: int,
+):
+ low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+ high_freq_wavelen = original_max_position_embeddings / high_freq_factor
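+    # Three regimes: high-frequency components (wavelength < high_freq_wavelen) are
+    # kept as-is, low-frequency components (wavelength > low_freq_wavelen) are divided
+    # by scaling_factor, and everything in between is smoothly interpolated.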
+ new_freqs = []
+
+ for freq in freqs:
+ wavelen = 2 * math.pi / freq
+
+ if wavelen < high_freq_wavelen:
+ new_freqs.append(freq)
+ elif wavelen > low_freq_wavelen:
+ new_freqs.append(freq / scaling_factor)
+ else:
+ assert low_freq_wavelen != high_freq_wavelen
+ smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+ high_freq_factor - low_freq_factor
+ )
+ new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
+ def __init__(
+ self,
+ inv_freq: torch.Tensor,
+ scaling_factor: float,
+ sections: list,
+ max_position_embeddings,
+ ):
+ self.sections = sections
+ self._cos_cached = None
+ self._sin_cached = None
+ self.section_indices = (
+ torch.arange(len(self.sections))
+ .repeat_interleave(torch.tensor(self.sections))
+ .view(1, 1, -1)
+ .to(inv_freq.device)
+ )
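+        # One section id per rotary dimension, e.g. sections=[2, 3] gives
+        # [0, 0, 1, 1, 1]; get_cos_sin later uses this to pick, per dimension,
+        # which axis of the multi-axis position ids to read (presumably one per
+        # modality axis, e.g. temporal/height/width).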
+ super().__init__(inv_freq, scaling_factor, max_position_embeddings)
+
+ def _update_cos_sin_cache(
+ self, dtype: torch.dtype, device: torch.device, seqlen: int
+ ):
+        # Only recompute the cos/sin cache when the requested sequence length
+        # exceeds the cached one (or the dtype/device changed); shorter
+        # sequences reuse the existing cache.
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ ):
+ self._seq_len_cached = seqlen
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+ self._sections = self.section_indices.expand(seqlen, -1, -1)
+
+ def get_cos_sin(
+ self,
+ position_ids: torch.Tensor,
+ ):
+ slen = position_ids.shape[0]
+
+ cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
+ sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
+ return cos, sin
diff --git a/backends/gaudi/server/text_generation_server/layers/speculative.py b/backends/gaudi/server/text_generation_server/layers/speculative.py
new file mode 100644
index 000000000..cf8469b53
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/speculative.py
@@ -0,0 +1,52 @@
+import torch
+import json
+from typing import Tuple, Optional
+from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
+from text_generation_server.layers.mlp import MLPSpeculatorHead
+
+
+class SpeculativeHead(torch.nn.Module):
+ def __init__(self, lm_head, speculator):
+ super().__init__()
+ self.head = lm_head
+ self.speculator = speculator
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ speculator = config.speculator
+ if speculator:
+ speculator_path = config.speculator["path"]
+ speculator_config = str(speculator_path / "config.json")
+
+ with open(speculator_config, "r") as f:
+ speculator_config = json.load(f)
+
+ config.speculator_config = speculator_config
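+            # Resolution order below: an MLP speculator when the declared
+            # architecture matches, a Medusa head (V1, then V2) when no
+            # architecture is declared; in either case the regular LM head is
+            # not loaded.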
+ try:
+ architecture = speculator_config["architectures"][0]
+
+ if architecture == "MLPSpeculatorPreTrainedModel":
+ speculator = MLPSpeculatorHead.load(config, prefix, weights)
+ else:
+ speculator = None
+ except KeyError:
+ try:
+ speculator = MedusaHeadV1.load(config, prefix, weights)
+ except Exception:
+ speculator = MedusaHeadV2(config, prefix, weights)
+ lm_head = None
+ else:
+ lm_head = TensorParallelHead.load(config, prefix, weights)
+ speculator = None
+ return SpeculativeHead(lm_head, speculator)
+
+ def forward(
+ self, input: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if self.speculator is not None:
+ return self.speculator(input)
+
+ assert self.head is not None
+ logits = self.head(input)
+ return logits, None
diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py
new file mode 100644
index 000000000..8f19174f8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py
@@ -0,0 +1,244 @@
+import torch
+from torch.nn import functional as F
+from typing import Iterable, List
+from text_generation_server.layers.linear import get_linear, FastLinear
+
+import habana_frameworks.torch as htorch
+
+
+class LayerConcat(torch.nn.Module):
+ """
+ Apply multiple layers to the input and concatenate their
+ outputs.
+ """
+
+ def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
+ """
+ `dim` is the dimension along which layer outputs are concatenated.
+ """
+ super().__init__()
+ self.layers = layers
+ self.dim = dim
+
+ def forward(self, x: torch.Tensor):
+ outputs = [layer(x) for layer in self.layers]
+ return torch.cat(outputs, self.dim)
+
+
+class SuperLayer(torch.nn.Module):
+ def __init__(self, linear):
+ super().__init__()
+ self.linear = linear
+
+ def forward(self, x):
+ return self.linear.forward(x)
+
+
+class TensorParallelHead(SuperLayer):
+ def __init__(self, linear, process_group, should_gather: bool):
+ super().__init__(linear)
+ self.process_group = process_group
+ self.should_gather = should_gather
+
+ @staticmethod
+ def load(config, prefix: str, weights):
+ if config.quantize == "exl2":
+ try:
+                # If the input embeddings and the LM head are tied, we have
+                # non-quantized weights...
+ weight = weights.get_tensor(f"{prefix}.weight")
+ except Exception:
+ # ...otherwise they are quantized.
+ weight = weights.get_weights_col(prefix)
+ should_gather = weights.process_group.size() > 1
+ elif weights.process_group.size() > 1:
+ try:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+ should_gather = True
+ except AssertionError:
+ # If the vocab size is not divisible by number of shards
+ # just load the entire thing.
+ weight = weights.get_tensor(f"{prefix}.weight")
+ should_gather = False
+ else:
+ weight = weights.get_tensor(f"{prefix}.weight")
+ should_gather = False
+
+ return TensorParallelHead(
+ get_linear(weight, bias=None),
+ process_group=weights.process_group,
+ should_gather=should_gather,
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if not self.should_gather:
+ return super().forward(input)
+
+ world_size = self.process_group.size()
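+        # Fast path for 2-D inputs with an unquantized FastLinear: each rank
+        # writes its shard of the logits straight into the all_gather buffer,
+        # then the full vocabulary is reassembled across ranks. The general
+        # path below gathers the full per-rank outputs and concatenates them.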
+ if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+ out_dim = self.linear.weight.shape[0]
+
+ if input.shape[0] == 1:
+ world_out = input.new_empty(1, out_dim * world_size)
+ local_out = input.new_empty(1, out_dim)
+ gather_input = local_out
+ else:
+ world_out = input.new_empty(out_dim * world_size, input.shape[0])
+ gather_input = input.new_empty(out_dim, input.shape[0])
+ local_out = gather_input.T
+
+ torch.mm(input, self.linear.weight.T, out=local_out)
+ htorch.core.mark_step()
+ torch.distributed.all_gather_into_tensor(
+ world_out, gather_input, group=self.process_group
+ )
+
+ if input.shape[0] == 1:
+ return world_out
+ return world_out.T
+
+ output = super().forward(input)
+ world_output = [
+ torch.empty_like(output) for _ in range(self.process_group.size())
+ ]
+
+ htorch.core.mark_step()
+ torch.distributed.all_gather(world_output, output, group=self.process_group)
+ world_output = torch.cat(world_output, dim=-1)
+ return world_output
+
+
+class TensorParallelColumnLinear(SuperLayer):
+ @classmethod
+ def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+ """Specific method when the QKV was joined after the fact"""
+ weight = weights.get_weights_col_packed_gate_up(prefix)
+ if bias:
+ raise NotImplementedError("packed_gate_up only implemented without bias")
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+ @classmethod
+ def load_qkv(
+ cls,
+ config,
+ prefix: str,
+ weights,
+ bias: bool,
+ num_heads: int,
+ num_key_value_heads: int,
+ ):
+ """Specific method when the QKV was joined after the fact"""
+ weight = weights.get_weights_col_packed_qkv(
+ prefix,
+ num_heads=num_heads,
+ num_key_value_heads=num_key_value_heads,
+ )
+ if bias:
+ raise NotImplementedError("packed_qkv only implemented for baichuan")
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+ @classmethod
+ def load(cls, config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_col(prefix)
+ if bias:
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+ @classmethod
+ def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+ if config.quantize == "exl2":
+ linears = []
+ for prefix in prefixes:
+ weight = weights.get_weights_col(prefix)
+ b = weights.get_tensor(f"{prefix}.bias") if bias else None
+ linears.append(get_linear(weight, b))
+ linear = LayerConcat(linears)
+ else:
+ weight = weights.get_multi_weights_col(prefixes, dim=dim)
+ if bias:
+ b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
+ bias = torch.cat(b, dim=dim)
+ else:
+ bias = None
+ linear = get_linear(weight, bias)
+ return cls(linear)
+
+
+class TensorParallelRowLinear(SuperLayer):
+ def __init__(self, linear, process_group):
+ super().__init__(linear)
+ self.process_group = process_group
+
+ @classmethod
+ def load(cls, config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+            # The bias is only applied on the first rank; adding it on every
+            # rank would sum it world_size times in the all_reduce.
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+ return cls(
+ get_linear(weight, bias),
+ process_group=weights.process_group,
+ )
+
+ def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
+ out = super().forward(input)
+ if self.process_group.size() > 1 and reduce:
+ # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+ # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+ # (which is required for tensor parallel HPUGraph inference)
+ htorch.core.mark_step()
+ torch.distributed.all_reduce(out, group=self.process_group)
+ return out
+
+
+class TensorParallelEmbedding(torch.nn.Module):
+ def __init__(self, prefix: str, weights, reduce=True):
+ super().__init__()
+ weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
+ num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
+
+ process_group = weights.process_group
+
+ world_size = process_group.size()
+ rank = process_group.rank()
+
+ block_size = (num_embeddings + world_size - 1) // world_size
+ self.min_id = rank * block_size
+ self.max_id = min(num_embeddings, (rank + 1) * block_size)
+ self.null_idx = weight.shape[
+ 0
+ ] # Usually block_size, might be less in non even vocab_size.
+ self.process_group = weights.process_group
+ self.reduce = reduce
+
+ """Additional 0 entry used for masking"""
+ self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Default all out-of-bounds values to `self.null_idx` (which maps to the zero row)
+        # and translate in-shard ids into the local range [0, self.max_id - self.min_id).
+ input = torch.where(
+ (self.min_id > input) | (input >= self.max_id),
+ self.null_idx,
+ input - self.min_id,
+ )
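+        # Ids owned by other shards now point at `self.null_idx`, i.e. the zero
+        # row appended in __init__, so they contribute nothing here and the
+        # all_reduce below fills in the embedding from the rank that owns them.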
+ out = torch.nn.functional.embedding(input, self.weight)
+ if self.reduce and self.process_group.size() > 1:
+ # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+ # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+ # (which is required for tensor parallel HPUGraph inference)
+ htorch.core.mark_step()
+ torch.distributed.all_reduce(out, group=self.process_group)
+ return out
diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
new file mode 100644
index 000000000..778b14a1b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -0,0 +1,994 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
+import torch
+import os
+
+from loguru import logger
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.auto import modeling_auto
+from huggingface_hub import hf_hub_download, HfApi
+from typing import Optional
+from pathlib import Path
+from typing import List, Dict
+import enum
+
+# Needed to properly setup habana_frameworks
+
+from text_generation_server.utils.speculate import get_speculate, set_speculate
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM
+from text_generation_server.models.starcoder import StarCoder
+from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
+ PhiMoEConfig,
+)
+
+from text_generation_server.utils.adapter import (
+ AdapterParameters,
+ build_layer_weight_lookup,
+ load_and_merge_adapters,
+ AdapterInfo,
+)
+from text_generation_server.adapters.lora import LoraWeights
+
+from text_generation_server.utils.log import log_master
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+__all__ = [
+ "Model",
+ "CausalLM",
+ "Seq2SeqLM",
+ "get_model_with_lora_adapters",
+]
+from text_generation_server.models.globals import ATTENTION
+
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
+
+FLASH_ATTENTION = False
+if ATTENTION == "paged":
+ FLASH_ATTENTION = True
+
+try:
+ from text_generation_server.models.flash_causal_lm import FlashCausalLM
+ from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
+ from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
+ from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
+ FlashDeepseekV2ForCausalLM,
+ DeepseekV2Config,
+ )
+ from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
+ FlashDeepseekV3ForCausalLM,
+ DeepseekV3Config,
+ )
+ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
+ FlashCohereForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+ FlashGemmaForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+ FlashGemma2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
+ FlashDbrxForCausalLM,
+ DbrxConfig,
+ )
+ from text_generation_server.models.custom_modeling.flash_rw_modeling import (
+ RWConfig,
+ FlashRWForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+ FlashGPTNeoXForCausalLM,
+ )
+ from text_generation_server.models.pali_gemma import (
+ PaliGemmaBatch,
+ )
+ from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+ PaliGemmaForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.flash_phi_modeling import (
+ FlashPhiForCausalLM,
+ )
+ from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
+ from text_generation_server.models.custom_modeling.flash_mllama import (
+ FlashMllamaForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.flash_llava_next import (
+ FlashLlavaNextForConditionalGeneration,
+ )
+
+ from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
+ FlashSantacoderForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
+ FlashStarcoder2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+ Qwen2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+ FlashMistralForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
+ FlashMixtralForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
+ FlashGPT2ForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
+ FlashGPTJForCausalLM,
+ )
+ from text_generation_server.models.custom_modeling.idefics2 import (
+ Idefics2ForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.idefics3 import (
+ Idefics3ForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.qwen2_vl import (
+ Qwen2VLForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.qwen2_5_vl import (
+ Qwen2_5VLForConditionalGeneration,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLProcessor,
+ )
+ from text_generation_server.layers.attention import SUPPORTS_WINDOWING
+except ImportError as e:
+ log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
+ SUPPORTS_WINDOWING = False
+ FLASH_ATTENTION = False
+
+if FLASH_ATTENTION:
+    __all__.append("FlashCausalLM")
+
+
+class ModelType(enum.Enum):
+ DEEPSEEK_V2 = {
+ "type": "deepseek_v2",
+ "name": "Deepseek V2",
+ "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
+ }
+ DEEPSEEK_V3 = {
+ "type": "deepseek_v3",
+ "name": "Deepseek V3",
+ "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
+ }
+ IDEFICS2 = {
+ "type": "idefics2",
+ "name": "Idefics 2",
+ "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
+ "multimodal": True,
+ }
+ IDEFICS3 = {
+ "type": "idefics3",
+ "name": "Idefics 3",
+ "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
+ "multimodal": True,
+ }
+ LLAVA_NEXT = {
+ "type": "llava_next",
+ "name": "Llava Next (1.6)",
+ "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
+ "multimodal": True,
+ }
+ LLAMA = {
+ "type": "llama",
+ "name": "Llama",
+ "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
+ }
+ PHI3 = {
+ "type": "phi3",
+ "name": "Phi 3",
+ "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+ }
+ GRANITE = {
+ "type": "granite",
+ "name": "Granite",
+ "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
+ }
+ GEMMA = {
+ "type": "gemma",
+ "name": "Gemma",
+ "url": "https://huggingface.co/google/gemma-7b",
+ }
+ PALIGEMMA = {
+ "type": "paligemma",
+ "name": "PaliGemma",
+ "url": "https://huggingface.co/google/paligemma-3b-pt-224",
+ }
+ GEMMA2 = {
+ "type": "gemma2",
+ "name": "Gemma2",
+ "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
+ }
+ COHERE = {
+ "type": "cohere",
+ "name": "Cohere",
+ "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
+ }
+ DBRX = {
+ "type": "dbrx",
+ "name": "Dbrx",
+ "url": "https://huggingface.co/databricks/dbrx-instruct",
+ }
+ MAMBA = {
+ "type": "mamba",
+ "name": "Mamba",
+ "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
+ }
+ MISTRAL = {
+ "type": "mistral",
+ "name": "Mistral",
+ "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
+ }
+ MIXTRAL = {
+ "type": "mixtral",
+ "name": "Mixtral",
+ "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
+ }
+ GPT_BIGCODE = {
+ "type": "gpt_bigcode",
+ "name": "Gpt Bigcode",
+ "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
+ }
+ PHI = {
+ "type": "phi",
+ "name": "Phi",
+ "url": "https://huggingface.co/microsoft/phi-1_5",
+ }
+ PHI_MOE = {
+ "type": "phimoe",
+ "name": "PhiMoe",
+ "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
+ }
+ BAICHUAN = {
+ "type": "baichuan",
+ "name": "Baichuan",
+ "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
+ }
+ FALCON = {
+ "type": "falcon",
+ "name": "Falcon",
+ "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
+ }
+ STARCODER2 = {
+ "type": "starcoder2",
+ "name": "StarCoder 2",
+ "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+ }
+ QWEN2 = {
+ "type": "qwen2",
+ "name": "Qwen 2",
+ "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
+ }
+ QWEN2_VL = {
+ "type": "qwen2_vl",
+ "name": "Qwen 2 VL",
+ "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
+ }
+ QWEN2_5_VL = {
+ "type": "qwen2_5_vl",
+ "name": "Qwen 2.5 VL",
+ "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
+ }
+ GALACTICA = {
+ "type": "galactica",
+ "name": "Galactica",
+ "url": "https://huggingface.co/facebook/galactica-120b",
+ }
+ SANTACODER = {
+ "type": "santacoder",
+ "name": "SantaCoder",
+ "url": "https://huggingface.co/bigcode/santacoder",
+ }
+ GPT2 = {
+ "type": "gpt2",
+ "name": "Gpt2",
+ "url": "https://huggingface.co/openai-community/gpt2",
+ }
+ GPT_NEOX = {
+ "type": "gpt_neox",
+ "name": "Gpt Neox",
+ "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
+ }
+ GPTJ = {
+ "type": "gptj",
+ "name": "Gptj",
+ "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
+ }
+ MLLAMA = {
+ "type": "mllama",
+ "name": "Mllama",
+ "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
+ "multimodal": True,
+ }
+
+
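+# Expose each ModelType's `type` string as a module-level constant, e.g.
+# LLAMA == "llama" and DEEPSEEK_V2 == "deepseek_v2"; these are the
+# undefined-looking names silenced by the `# ruff: noqa: F821` at the top.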
+__GLOBALS = locals()
+for data in ModelType:
+ __GLOBALS[data.name] = data.value["type"]
+
+SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
+# Disable gradients
+torch.set_grad_enabled(False)
+
+
+def get_model(
+ model_id: str,
+ lora_adapter_ids: Optional[List[str]],
+ revision: Optional[str],
+ sharded: bool,
+ quantize: Optional[str],
+ speculate: Optional[int],
+ dtype: Optional[torch.dtype],
+ trust_remote_code: bool,
+ max_input_tokens: int,
+) -> Model:
+ global FLASH_ATTENTION
+
+ if speculate is not None:
+ set_speculate(speculate)
+ else:
+ set_speculate(0)
+
+ config_dict, _ = PretrainedConfig.get_config_dict(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ model_type = config_dict.get("model_type", None)
+
+ speculator = None
+ if "medusa_num_heads" in config_dict:
+ medusa_model_id = model_id
+ medusa_revision = revision
+ model_id = config_dict["base_model_name_or_path"]
+ revision = "main"
+ speculate_medusa = config_dict["medusa_num_heads"]
+ if speculate is not None:
+ if speculate > speculate_medusa:
+ raise RuntimeError(
+ f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+ )
+ else:
+ set_speculate(speculate)
+ else:
+ set_speculate(speculate_medusa)
+
+ config_dict, _ = PretrainedConfig.get_config_dict(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ # Reload model type from parent.
+ model_type = config_dict.get("model_type", None)
+ is_local = Path(medusa_model_id).exists()
+ if not is_local:
+ medusa_config = hf_hub_download(
+ medusa_model_id, revision=medusa_revision, filename="config.json"
+ )
+ hf_hub_download(
+ medusa_model_id,
+ revision=medusa_revision,
+ filename="medusa_lm_head.safetensors",
+ )
+ speculator = {
+ "path": Path(medusa_config).parent,
+ "model_paths": ["medusa_lm_head.safetensors"],
+ }
+ else:
+ speculator = {
+ "path": Path(medusa_model_id),
+ "model_paths": ["medusa_lm_head.safetensors"],
+ }
+
+ method = "medusa"
+ elif model_type == "mlp_speculator":
+ mlp_model_id = model_id
+ mlp_revision = revision
+ model_id = config_dict["base_model_name_or_path"]
+ revision = "main"
+ speculate_mlp = config_dict["n_predict"]
+ if speculate is not None:
+ if speculate > speculate_mlp:
+ raise RuntimeError(
+ f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
+ )
+ else:
+ set_speculate(speculate)
+ else:
+ set_speculate(speculate_mlp)
+
+ config_dict, _ = PretrainedConfig.get_config_dict(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ # Reload model type from parent.
+ model_type = config_dict.get("model_type", None)
+ is_local = Path(mlp_model_id).exists()
+ extension = ".safetensors"
+ if not is_local:
+ mlp_speculator_config = hf_hub_download(
+ mlp_model_id, revision=mlp_revision, filename="config.json"
+ )
+ api = HfApi()
+ info = api.model_info(mlp_model_id, revision=mlp_revision)
+ filenames = [
+ s.rfilename
+ for s in info.siblings
+ if s.rfilename.endswith(extension)
+ and len(s.rfilename.split("/")) == 1
+ and "arguments" not in s.rfilename
+ and "args" not in s.rfilename
+ and "training" not in s.rfilename
+ ]
+ for filename in filenames:
+ hf_hub_download(
+ mlp_model_id,
+ revision=mlp_revision,
+ filename=filename,
+ )
+ speculator_dir_path = Path(mlp_speculator_config).parent
+ # if these are downloaded, they get converted to safetensors
+ filenames.extend(
+ [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
+ )
+ speculator = {
+ "path": Path(mlp_speculator_config).parent,
+ "model_paths": filenames,
+ }
+ else:
+ speculator = Path(mlp_model_id)
+ filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+ speculator = {"path": speculator, "model_paths": filenames}
+ method = "mlp_speculator"
+ else:
+ method = "n-gram"
+
+ speculate = get_speculate()
+ if speculate > 0:
+ logger.info(f"Using speculation {method} with {speculate} input ids.")
+
+ model_type = config_dict["model_type"]
+
+ kv_cache_dtype = dtype
+
+ if FLASH_ATTENTION:
+ if model_type == DEEPSEEK_V2:
+ head_size = max(
+ config_dict.get("qk_nope_dim", 128)
+ + config_dict.get("qk_rope_dim", 64),
+ config_dict.get("v_head_dim", 128),
+ )
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashDeepseekV2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ default_dtype=torch.bfloat16,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=DeepseekV2Config,
+ head_size=head_size,
+ )
+ elif model_type == DEEPSEEK_V3:
+ head_size = max(
+ config_dict.get("qk_nope_dim", 128)
+ + config_dict.get("qk_rope_dim", 64),
+ config_dict.get("v_head_dim", 128),
+ )
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashDeepseekV3ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ default_dtype=torch.bfloat16,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=DeepseekV3Config,
+ head_size=head_size,
+ )
+
+ elif (
+ model_type == GPT_BIGCODE
+            or (model_type == GPT2 and model_id.startswith("bigcode/"))
+ ):
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashSantacoderForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ aliases={"transformer.wte.weight": ["lm_head.weight"]},
+ num_kv_heads=1,
+ )
+ elif model_type == GPT2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGPT2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GPTJ:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGPTJForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GPT_NEOX:
+ from text_generation_server.models.custom_modeling.flash_neox_modeling import (
+ GPTNeoXConfig,
+ )
+
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGPTNeoXForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=GPTNeoXConfig,
+ )
+ elif model_type == PHI:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashPhiForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == PHI_MOE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashLlamaForCausalLM,
+ config_class=PhiMoEConfig,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashLlamaForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == BAICHUAN:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashLlamaForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GEMMA:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGemmaForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == GEMMA2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashGemma2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == COHERE:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashCohereForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == DBRX:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashDbrxForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Dbrx works better in bfloat16.
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=DbrxConfig,
+ )
+ elif (
+ model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
+ and not sharded
+ and not config_dict.get("alibi", False)
+ ):
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashRWForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ aliases={
+ "lm_head.weight": ["transformer.word_embeddings.weight"],
+ "transformer.word_embeddings.weight": ["lm_head.weight"],
+ },
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=RWConfig,
+ )
+ elif model_type == MISTRAL:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashMistralForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == MIXTRAL:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashMixtralForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == STARCODER2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=FlashStarcoder2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == QWEN2:
+ return FlashCausalLM(
+ model_id=model_id,
+ model_class=Qwen2ForCausalLM,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == QWEN2_VL:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Qwen2VLForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == QWEN2_5_VL:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Qwen2_5VLForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ config_class=Qwen2_5_VLConfig,
+ processor_class=Qwen2_5_VLProcessor,
+ )
+ elif model_type == MLLAMA:
+ return FlashMllamaCausalLM(
+ model_id=model_id,
+ model_class=FlashMllamaForConditionalGeneration,
+ batch_class=FlashMllamaCausalLMBatch,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ )
+ elif model_type == IDEFICS2:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Idefics2ForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ # XXX: Extremely important to cap resolution in order to limit
+ # VRAM usage.
+ processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
+ )
+ elif model_type == IDEFICS3:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=Idefics3ForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ # XXX: Extremely important to cap resolution in order to limit
+ # VRAM usage.
+ processor_kwargs={"size": {"longest_edge": 1456}},
+ )
+ elif model_type == PALIGEMMA:
+ return FlashVlmCausalLM(
+ model_id=model_id,
+ model_class=PaliGemmaForConditionalGeneration,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ # Works better for these models
+ default_dtype=torch.bfloat16,
+ trust_remote_code=trust_remote_code,
+ lora_adapter_ids=lora_adapter_ids,
+ batch_class=PaliGemmaBatch,
+ )
+ elif model_type == LLAVA_NEXT:
+ return FlashVlmCausalLM(
+ model_class=FlashLlavaNextForConditionalGeneration,
+ model_id=model_id,
+ revision=revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+ from text_generation_server.models.custom_modeling.mllama import (
+ MllamaForConditionalGeneration,
+ )
+ from text_generation_server.models.custom_modeling.llava_next import (
+ LlavaNextForConditionalGeneration,
+ )
+
+ adapt_transformers_to_gaudi()
+ if SDP_ON_BF16 == 1:
+ torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
+ if model_type == "gpt_bigcode":
+ return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
+ if model_type == "bloom":
+ return BLOOM(
+ model_id=model_id,
+ revision=revision,
+ speculator=speculator,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ if model_type == "llava_next":
+ return VlmCausalLM(
+ model_class=LlavaNextForConditionalGeneration,
+ model_id=model_id,
+ revision=revision,
+ quantize=None,
+ speculator=speculator,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ if model_type == "mllama":
+ return VlmCausalLM(
+ model_class=MllamaForConditionalGeneration,
+ model_id=model_id,
+ revision=revision,
+ quantize=None,
+ speculator=speculator,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+ return CausalLM(
+ model_id,
+ revision,
+ quantize=quantize,
+ speculator=speculator,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ raise ValueError(f"Unsupported model type {model_type}")
+
+
+# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters:
+# it provides a post-loading hook that injects adapter weights once the base model has been instantiated.
+def get_model_with_lora_adapters(
+ model_id: str,
+ lora_adapters: Optional[List[AdapterInfo]],
+ revision: Optional[str],
+ sharded: bool,
+ quantize: Optional[str],
+ speculate: Optional[int],
+ dtype: Optional[torch.dtype],
+ trust_remote_code: bool,
+ max_input_tokens: int,
+ adapter_to_index: Dict[str, int],
+):
+ lora_adapter_ids = [adapter.id for adapter in lora_adapters]
+ model = get_model(
+ model_id,
+ lora_adapter_ids,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ trust_remote_code,
+ max_input_tokens,
+ )
+
+ if len(lora_adapters) > 0:
+ target_to_layer = build_layer_weight_lookup(model.model)
+
+ for index, adapter in enumerate(lora_adapters):
+ # The AdapterParameters object allows for merging multiple adapters into a single adapter.
+ # At the moment, we only support loading a single adapter into the model, but we keep the
+ # AdapterParameters object for easier extension in the future.
+ adapter_parameters = AdapterParameters(
+ adapter_info=[adapter],
+ # when merging multiple adapters we can weight them differently
+ # if this is not set, all adapters will be weighted equally
+ # see: text_generation_server.utils.merges.strategies for impl
+ weights=None,
+ merge_strategy=0,
+ density=1.0,
+ majority_sign_method=0,
+ )
+
+ adapter_index = index + 1
+ adapter_to_index[adapter.id] = adapter_index
+
+ logger.info(
+ f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
+ )
+ weight_names = tuple([v[0] for v in target_to_layer.values()])
+ (
+ module_map,
+ adapter_config,
+ adapter_weight_names,
+ adapter_tokenizer,
+ ) = load_and_merge_adapters(
+ model.model_id,
+ adapter_parameters,
+ adapter_index,
+ weight_names,
+ False,
+ )
+
+ unused_weight_names = adapter_weight_names.copy()
+
+ adapter_layers = [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ "qkv_proj",
+ ]
+
+ for layer_name in adapter_layers:
+ nlayers = (
+ 1 if layer_name == "lm_head" else len(model.model.model.layers)
+ )
+ adapter_weights = LoraWeights.prepare_weights(
+ config=adapter_config,
+ module_map=module_map,
+ layer_type=layer_name,
+ unused_weight_names=unused_weight_names,
+ nlayers=nlayers,
+ dtype=model.dtype,
+ world_size=model.world_size,
+ process_group=model.process_group,
+ target_to_layer=target_to_layer,
+ )
+
+ if adapter_weights is None:
+ continue
+
+ model.layer_to_adapter_weights[layer_name].add_adapter(
+ adapter_index, adapter_weights
+ )
+
+ if len(unused_weight_names) > 0:
+ logger.warning(
+ f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
+ )
+
+ if adapter_tokenizer is not None:
+ model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+ model.loaded_adapters.add(adapter_index)
+
+ return model
diff --git a/backends/gaudi/server/text_generation_server/models/bloom.py b/backends/gaudi/server/text_generation_server/models/bloom.py
new file mode 100644
index 000000000..6fe643748
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/bloom.py
@@ -0,0 +1,52 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import torch
+
+from typing import Optional, Type
+
+from transformers import PreTrainedTokenizerBase
+
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+
+
+class BloomCausalLMBatch(CausalLMBatch):
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "CausalLMBatch":
+ batch = super().from_pb(
+ pb=pb,
+ tokenizer=tokenizer,
+ dtype=dtype,
+ device=device,
+ )
+ batch.keys_head_dim_last = False
+ return batch
+
+
+class BLOOM(CausalLM):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
+ super(BLOOM, self).__init__(
+ model_id=model_id,
+ revision=revision,
+ speculator=speculator,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
+
+ @property
+ def batch_type(self) -> Type[CausalLMBatch]:
+ return BloomCausalLMBatch
diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py
new file mode 100644
index 000000000..c1ce3335f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py
@@ -0,0 +1,1426 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import bisect
+from dataclasses import dataclass
+from functools import wraps
+import itertools
+import math
+import os
+import tempfile
+import time
+import copy
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+import torch._dynamo
+from loguru import logger
+from opentelemetry import trace
+
+import text_generation_server.habana_quantization_env as hq_env
+import habana_frameworks.torch as htorch
+from optimum.habana.utils import HabanaProfile
+from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
+from text_generation_server.utils.chunks import concat_text_chunks
+from optimum.habana.checkpoint_utils import (
+ get_repo_root,
+ model_on_meta,
+ write_checkpoints_json,
+)
+from transformers import (
+ AutoTokenizer,
+ AutoModelForCausalLM,
+ PreTrainedTokenizerBase,
+ AutoConfig,
+)
+
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+ Batch,
+ Tokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
+ HeterogeneousNextTokenChooser,
+ StoppingCriteria,
+ is_tokenizer_transparent,
+ pad_next_token_chooser_parameters,
+)
+from optimum.habana.utils import get_hpu_memory_stats
+from text_generation_server.utils.debug import dbg_trace
+from text_generation_server.utils.speculate import get_speculate
+
+tracer = trace.get_tracer(__name__)
+MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 2048))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
+CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
+BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
+MAX_BATCH_SIZE = (
+ int(os.environ.get("MAX_BATCH_SIZE"))
+ if os.environ.get("MAX_BATCH_SIZE") is not None
+ else None
+)
+
+
+def torch_compile_for_eager(func):
+ if LAZY_MODE == 1:
+ return func
+ return torch.compile(
+ func, backend="hpu_backend", options={"keep_input_mutations": True}
+ )
+
+
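+# Round `number` up to the next multiple of `k`, e.g. round_up_seq(300, 256) == 512.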
+def round_up_seq(number, k):
+ return (number + k - 1) // k * k
+
+
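+# Round a batch size up to the next power of BATCH_SIZE_EXPONENT_BASE,
+# e.g. with the default base of 2, round_up_batch(5) == 8.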
+def round_up_batch(number):
+ return BATCH_SIZE_EXPONENT_BASE ** (
+ math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
+ )
+
+
+def to_tensor_indices(indices, device):
+ return torch.tensor(indices, dtype=torch.long, device=device)
+
+
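+# Greedily decompose `offset` into a sum of signed CHUNK_SIZES entries, e.g.
+# calculate_chunks(300) == [256, 32, 8, 4]; each chunk is then applied as a
+# separate roll by grouped_shift below.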
+def calculate_chunks(offset):
+ result = []
+ while offset != 0:
+ sign = 1 if offset > 0 else -1
+ best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
+ result.append(best_chunk)
+ offset = offset - best_chunk
+ return result
+
+
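+# Largest CHUNK_SIZES entry not exceeding |offset|, keeping offset's sign,
+# e.g. biggest_single_chunk(300) == 256 and biggest_single_chunk(-300) == -256.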
+def biggest_single_chunk(offset):
+ if offset != 0:
+ idx = bisect.bisect(CHUNK_SIZES, abs(offset))
+ return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
+ else:
+ return 0
+
+
+@torch_compile_for_eager
+def grouped_pad(tensor_groups, dims, values):
+ grouped_result = []
+ for tensors, dim, value in zip(tensor_groups, dims, values):
+ padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0
+ if padding > 0:
+ assert dim in [-1, -2], f"Only dims -1 and -2 are supported! {dim}"
+ pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
+ result = [
+ torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors
+ ]
+ else:
+ result = [t for t in tensors]
+ grouped_result.append(result)
+ htorch.core.mark_step()
+ return grouped_result
+
+
+@torch_compile_for_eager
+def roll(tensor, chunk, dim, merge_graphs):
+ if dim is None:
+ return tensor
+ tensor = torch.roll(tensor, chunk, dim)
+ if not merge_graphs:
+ htorch.core.mark_step()
+ return tensor
+
+
+def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
+ tensor_groups = [
+ [roll(t, chunk, dim, merge_graphs) for t in tensors]
+ for tensors, dim in zip(tensor_groups, dims)
+ ]
+ if merge_graphs:
+ htorch.core.mark_step()
+ return tensor_groups
+
+
+@torch_compile_for_eager
+def grouped_shift(tensor_groups, dims, offset, merge_graphs):
+ chunks = calculate_chunks(offset)
+ for c in chunks:
+ tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs)
+ return tensor_groups
+
+
+def move(dst_tensors, dst_indices, src_tensors):
+ bs_dim = 0
+ num_indices = dst_indices.size(0)
+ for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)):
+ if src_t.size(bs_dim) != num_indices:
+ src_t = torch.narrow(src_t, bs_dim, 0, num_indices)
+ dst_t.index_copy_(bs_dim, dst_indices, src_t)
+ htorch.core.mark_step()
+
+
+def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups):
+ for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups):
+ move(dst_tensors, dst_indices, src_tensors)
+
+
+@torch_compile_for_eager
+def extend_tensor(tensor, padding, dim):
+ result = torch.cat([tensor, padding], dim=dim)
+ htorch.core.mark_step()
+ return result
+
+
+@torch_compile_for_eager
+def extend_batch(tensors, target_bs, dim):
+ diff = target_bs - tensors[0].size(dim)
+ # TODO: add support for shrinking bs
+ if diff <= 0:
+ return tensors
+ shape = list(tensors[0].shape)
+ shape[dim] = diff
+ padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+ tensors = [extend_tensor(t, padding, dim) for t in tensors]
+ return tensors
+
+
+def grouped_extend_batch(tensor_groups, target_bs, bs_dims):
+ tensor_groups = [
+ extend_batch(tensors, target_bs, dim)
+ for tensors, dim in zip(tensor_groups, bs_dims)
+ ]
+ return tensor_groups
+
+
+@torch_compile_for_eager
+def merge(tensor_group):
+ tensor_group = [torch.stack(tensor_group)]
+ htorch.core.mark_step()
+ return tensor_group
+
+
+@torch_compile_for_eager
+def split(tensor_group, clone_data):
+ tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)]
+ if clone_data:
+ tensor_group = [t.clone() for t in tensor_group]
+ htorch.core.mark_step()
+ return tensor_group
+
+
+def remove_kv_cache_from_output(module):
+ orig_fwd = module.forward
+
+ @wraps(orig_fwd)
+ def forward(*args, **kwargs):
+ if kwargs["past_key_values"] is not None:
+ kwargs["return_dict"] = False
+ output = orig_fwd(*args, **kwargs)
+ first_value, second_value, *_ = output
+ if first_value.nelement() < 2:
+ return second_value
+ else:
+ return first_value
+ else:
+ kwargs["return_dict"] = True
+ return orig_fwd(*args, **kwargs)
+
+ module.forward = forward
+ return module
+
+
+@dataclass
+class CausalLMRequest:
+ idx: int
+ data: generate_pb2.Request
+ input_length: int
+ prefix_offset: int
+ read_offset: int
+ stopping_criteria: StoppingCriteria
+
+ all_input_ids: torch.Tensor
+
+ @classmethod
+ def from_pb(
+ cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase
+ ):
+ return cls(
+ idx=idx,
+ data=data,
+ input_length=None,
+ prefix_offset=None,
+ read_offset=None,
+ stopping_criteria=StoppingCriteria.from_pb(
+ data.stopping_parameters, tokenizer
+ ),
+ all_input_ids=None,
+ )
+
+ def update_idx(self, new_idx):
+ prev = self.idx
+ self.idx = new_idx
+ return (new_idx, prev)
+
+
+@dataclass
+class CausalLMBatch(Batch):
+ batch_id: int
+ requests: List[CausalLMRequest]
+
+ # Decoder values
+ input_ids: torch.Tensor
+ attention_mask: torch.Tensor
+ position_ids: torch.Tensor
+ past_key_values: Optional[List[Tuple]]
+ merged_kv_cache: bool
+
+ # Lengths of all generations present in the batch
+ input_length: int
+
+ # Generation helpers
+ next_token_chooser: HeterogeneousNextTokenChooser
+ top_n_tokens: List[int]
+ top_n_tokens_tensor: torch.Tensor
+
+ # Past metadata
+ logits = None
+ past = None
+
+ keys_head_dim_last: bool = True
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.data.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.max_tokens,
+ )
+
+ def detach_kv_cache(self):
+ past_keys = [past[0] for past in self.past_key_values]
+ past_values = [past[1] for past in self.past_key_values]
+ del self.past_key_values
+ return past_keys, past_values
+
+ def attach_kv_cache(self, past_keys, past_values):
+ # TODO: Add support for models that don't store kv_cache in a list
+ self.past_key_values = list(zip(past_keys, past_values))
+
+ def merge_kv_cache_if_needed(self, target_bs, offset):
+ pad_needed = self.seq_length < MAX_TOTAL_TOKENS
+ shift_needed = offset != 0
+ expand_needed = target_bs > self.batch_size
+ # Very simple heuristic to determine whether we should merge tensors
+ # this needs tuning for other models/scenarios
+ small_bs = len(self.past_key_values) > self.batch_size
+ if (
+ not self.merged_kv_cache
+ and small_bs
+ and (pad_needed or shift_needed or expand_needed)
+ ):
+ past_keys, past_values = self.detach_kv_cache()
+ past_keys = merge(past_keys)
+ past_values = merge(past_values)
+ self.attach_kv_cache(past_keys, past_values)
+ self.merged_kv_cache = True
+
+ def split_kv_cache_if_needed(self, clone_data):
+ if self.merged_kv_cache:
+ past_keys, past_values = self.detach_kv_cache()
+ past_keys = split(past_keys, clone_data)
+ past_values = split(past_values, clone_data)
+ self.attach_kv_cache(past_keys, past_values)
+ self.merged_kv_cache = False
+
+ def get_tensor_groups(self):
+ past_keys, past_values = self.detach_kv_cache()
+ seq_dim = -1
+ key_dim = -2 if self.keys_head_dim_last else -1
+ value_dim = -2
+ tensors = [
+ [self.input_ids],
+ [self.attention_mask],
+ [self.position_ids],
+ past_keys,
+ past_values,
+ ]
+ # We don't need to align position_ids
+ seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim]
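+        # When the KV cache has been merged, `merge` stacked the per-layer
+        # tensors along dim 0, so the batch dimension of keys/values is 1
+        # instead of 0.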
+ bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0])
+ return tensors, seq_dims, bs_dims
+
+ def set_tensor_groups(self, tensors):
+ self.input_ids = tensors.pop(0)[0]
+ self.attention_mask = tensors.pop(0)[0]
+ self.position_ids = tensors.pop(0)[0]
+ past_keys = tensors.pop(0)
+ past_values = tensors.pop(0)
+ self.attach_kv_cache(past_keys, past_values)
+
+ def realign(self, target_bs, offset, pad_token_id):
+ tensors, seq_dims, _ = self.get_tensor_groups()
+ tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0])
+ tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache)
+ self.set_tensor_groups(tensors)
+
+ def expand_bs(self, target_bs):
+ tensors, _, bs_dims = self.get_tensor_groups()
+ tensors = grouped_extend_batch(tensors, target_bs, bs_dims)
+ self.set_tensor_groups(tensors)
+
+ def used_indices(self):
+ return [req.idx for req in self.requests]
+
+ def update_indices(self, new_indices):
+ for req, new_idx in zip(self.requests, new_indices):
+ req.idx = new_idx
+ return self.used_indices()
+
+ def free_indices_generator(self):
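+        # Example: with batch_size = 4 and live requests at indices {0, 2}, this
+        # yields 1 and then 3.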
+ used = set(req.idx for req in self.requests)
+ return (i for i in range(self.batch_size) if i not in used)
+
+ def move_data(self, src_batches):
+ dst_tensors, _, dst_dims = self.get_tensor_groups()
+ free_indices_gen = self.free_indices_generator()
+ for src_b in src_batches:
+ dst_indices = to_tensor_indices(
+ src_b.update_indices(free_indices_gen), self.input_ids.device
+ )
+ src_tensors, _, src_dims = src_b.get_tensor_groups()
+ grouped_move(dst_tensors, dst_indices, src_tensors)
+ self.set_tensor_groups(dst_tensors)
+
+ @classmethod
+ def recombine(
+ cls, batches: List["CausalLMBatch"], pad_token_id: int
+ ) -> "CausalLMBatch":
+ if not all(b.past_key_values is not None for b in batches):
+ raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
+
+ total_requests = sum(len(b) for b in batches)
+        new_bs = round_up_batch(total_requests)
+
+ batch_id = batches[0].batch_id
+ device = batches[0].input_ids.device
+
+ input_lengths = [b.input_length for b in batches]
+ max_input_length = max(input_lengths)
+ offsets = [max_input_length - b.input_length for b in batches]
+
+ cur_padding = [b.right_padding for b in batches]
+ # For prefill there is a space allocated only for first token
+ # Need to add padding to the max total tokens before first decode
+
+ moves_needed = [
+ total_requests - len(b) if b.batch_size == new_bs else total_requests
+ for b in batches
+ ]
+ dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0]
+ reshape = batches[dst_batch_idx].batch_size < new_bs
+
+ # TODO: Add support for changing max seq len, i.e. due to output length bucketing
+ # FIXME: max_seq_len for non optimized code
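+        # Three recombination scenarios are handled below:
+        #   CONCAT  - several batches are merged into a single batch,
+        #   RESHAPE - the destination batch must grow to the new batch size,
+        #   SHIFT   - the batch ran out of right padding and its data is shifted.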
+ if len(batches) > 1:
+ scenario = "CONCAT"
+ elif reshape:
+ scenario = "RESHAPE"
+ elif cur_padding[dst_batch_idx] <= 0:
+ scenario = "SHIFT"
+ offsets = [
+ biggest_single_chunk(b.max_input_length - max_input_length)
+ for b in batches
+ ]
+ max_input_length = max_input_length + offsets[dst_batch_idx]
+ else:
+ # Nothing to do
+ return batches[0]
+
+ dbg_trace(
+ scenario,
+ f"bs:{[b.batch_size for b in batches]}->{new_bs}"
+ f" reqs:{[len(b) for b in batches]}"
+ f" offsets:{offsets}"
+ f" input_lengths:{input_lengths}"
+ f" cur_padding:{cur_padding}"
+ f" dst_batch:{dst_batch_idx}",
+ )
+
+ grouped_requests = [[req for req in batch.requests] for batch in batches]
+ flat_requests = list(itertools.chain(*grouped_requests))
+
+ for i in range(len(batches)):
+ target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size
+ batches[i].merge_kv_cache_if_needed(target_bs, offsets[i])
+ batches[i].realign(target_bs, offsets[i], pad_token_id)
+ batches[i].split_kv_cache_if_needed(i == dst_batch_idx)
+ batches[dst_batch_idx].expand_bs(new_bs)
+ batches[dst_batch_idx].move_data(
+ [batches[i] for i in range(len(batches)) if i != dst_batch_idx]
+ )
+
+ top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
+ top_n_tokens.extend([-1] * (new_bs - total_requests))
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+
+ parameters = [r.data.parameters for r in flat_requests]
+ # append the dummy parameters for dummy requests
+ batch_size = batches[dst_batch_idx].batch_size
+ parameters = pad_next_token_chooser_parameters(parameters, batch_size)
+
+ # update past grammar states
+ fsm_grammar_states = [0] * batch_size
+ for batch in batches:
+ for i, req in enumerate(batch.requests):
+ fsm_grammar_states[req.idx] = (
+ batch.next_token_chooser.fsm_grammar_states[i]
+ )
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ parameters,
+ batches[dst_batch_idx].next_token_chooser.dtype,
+ batches[dst_batch_idx].next_token_chooser.device,
+ batches[dst_batch_idx].next_token_chooser.tokenizer,
+ fsm_grammar_states,
+ quantization_enabled=hq_env.is_quantization_enabled,
+ )
+
+ input_ids = batches[dst_batch_idx].input_ids
+ attention_mask = batches[dst_batch_idx].attention_mask
+ position_ids = batches[dst_batch_idx].position_ids
+ past_key_values = batches[dst_batch_idx].past_key_values
+ input_length = max_input_length
+
+ htorch.core.mark_step()
+
+ return cls(
+ batch_id=batch_id,
+ requests=flat_requests,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ merged_kv_cache=False,
+ next_token_chooser=next_token_chooser,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ input_length=input_length,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "CausalLMBatch":
+ dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
+ requests = [
+ CausalLMRequest.from_pb(idx, req, tokenizer)
+ for idx, req in enumerate(pb.requests)
+ ]
+ inputs = []
+ top_n_tokens = []
+
+ # Parse batch
+ max_truncation = 0
+ for i, r in enumerate(pb.requests):
+ inputs.append(concat_text_chunks(r.input_chunks.chunks))
+ top_n_tokens.append(r.top_n_tokens)
+ max_truncation = max(max_truncation, r.truncate)
+
+ max_input_length = max_truncation
+ if max_input_length < PAD_SEQUENCE_TO_MULTIPLE_OF:
+ max_input_length = PAD_SEQUENCE_TO_MULTIPLE_OF
+ max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
+
+        # TODO: by tokenizing all inputs at once we lose information on actual input lengths
+ # this means that we cannot shift inputs to the left after a long input sequence
+ # was filtered out
+ new_bs = round_up_batch(len(requests))
+ missing_inputs = new_bs - len(inputs)
+ dummy_inputs = ["?"] * missing_inputs
+ parameters = [r.parameters for r in pb.requests]
+ # append the dummy parameters for dummy request
+ parameters = pad_next_token_chooser_parameters(parameters, new_bs)
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ pb=parameters,
+ dtype=dtype,
+ device=device,
+ tokenizer=tokenizer,
+ quantization_enabled=hq_env.is_quantization_enabled,
+ )
+
+ tokenized_inputs = tokenizer(
+ inputs + dummy_inputs,
+ return_tensors="pt",
+ padding="longest",
+ return_token_type_ids=False,
+ truncation=True,
+ max_length=max_truncation,
+ )
+
+ input_len = tokenized_inputs["input_ids"].shape[1]
+ # Round up sequence length
+ bucket_size = max_input_length
+ left_padding = max_input_length - input_len
+ if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
+ assert (
+ PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
+ ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
+ rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+ if rounded_seq_len <= max_input_length:
+ bucket_size = rounded_seq_len - 1
+ else:
+ bucket_size = max_input_length - 1
+ left_padding = bucket_size - input_len
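+        # Example (assuming round_up_seq rounds up to the next multiple of its second
+        # argument): with PAD_SEQUENCE_TO_MULTIPLE_OF = 128, max_input_length = 512
+        # and input_len = 100, rounded_seq_len is 128, so bucket_size = 127 and
+        # left_padding = 27.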
+
+ input_ids = tokenized_inputs["input_ids"]
+ attention_mask = tokenized_inputs["attention_mask"]
+
+ # Allocate space for first token
+ input_ids = torch.nn.functional.pad(
+ input_ids, (left_padding, 1), value=tokenizer.pad_token_id
+ )
+ attention_mask = torch.nn.functional.pad(
+ attention_mask, (left_padding, 1), value=0
+ )
+ all_input_ids = torch.nn.functional.pad(
+ input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
+ ).T.split(1, dim=1)
+ input_len = bucket_size
+ for r in requests:
+ r.input_length = input_len
+ r.prefix_offset = input_len - 5
+ r.read_offset = input_len
+ r.all_input_ids = all_input_ids[r.idx]
+
+ input_ids = input_ids.to(device)
+ attention_mask = attention_mask.to(device)
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
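+        # Illustration: for a left-padded row with attention_mask [0, 0, 1, 1, 1],
+        # cumsum(-1) - 1 gives [-1, -1, 0, 1, 2] and the masked_fill_ turns the
+        # padded positions into 1, so position_ids becomes [1, 1, 0, 1, 2].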
+
+ old_bs = len(requests)
+ top_n_tokens.extend([-1] * (new_bs - old_bs))
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+ htorch.core.mark_step()
+ return cls(
+ batch_id=pb.id,
+ requests=requests,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=None,
+ merged_kv_cache=False,
+ next_token_chooser=next_token_chooser,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ input_length=input_len,
+ )
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
+ dbg_trace("FILTER", f"num_reqs:{len(self.requests)} -> {len(request_ids)}")
+ request_ids = set(request_ids)
+ self.requests = [req for req in self.requests if req.data.id in request_ids]
+ return self
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(
+ cls, batches: List["CausalLMBatch"], pad_token_id: int = 0
+ ) -> "CausalLMBatch":
+ return cls.recombine(batches, pad_token_id)
+
+ def __len__(self):
+ return len(self.requests)
+
+ @property
+ def max_input_length(self):
+ return max(req.input_length for req in self.requests)
+
+ @property
+ def batch_size(self):
+ return self.attention_mask.size(0)
+
+ @property
+ def seq_length(self):
+ return self.attention_mask.size(1)
+
+ @property
+ def right_padding(self):
+ return self.seq_length - self.input_length
+
+ # Maximum number of tokens this batch will grow to
+ @property
+ def max_tokens(self):
+ max_total_tokens = self.attention_mask.size(1)
+ return len(self.requests) * max_total_tokens
+
+
+class CausalLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ model_class: Optional[Type[torch.nn.Module]] = None,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ default_dtype=torch.float16,
+ trust_remote_code: bool = False,
+ tokenizer_class=AutoTokenizer,
+ config_class=AutoConfig,
+ batch_class=CausalLMBatch,
+ ):
+ if speculator:
+ raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+ self.prev_bs = 0
+ self.quantize = quantize
+
+ # Create tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+
+ # Create model
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ rank = int(os.getenv("RANK", "0"))
+ dtype = torch.bfloat16 if dtype is None else dtype
+ device = torch.device("hpu")
+
+ if hq_env.is_quantization_enabled:
+ htorch.core.hpu_set_env()
+
+ if world_size > 1:
+ os.environ.setdefault(
+ "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
+ )
+ model = self.get_deepspeed_model(model_id, dtype, revision)
+ model = hq_env.prepare_model_for_quantization(model)
+ else:
+ get_repo_root(model_id)
+
+ # Check support for rope scaling
+ model_kwargs = {}
+ config = AutoConfig.from_pretrained(model_id)
+ if hasattr(config, "rope_scaling"):
+ model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ **model_kwargs,
+ )
+ model = hq_env.prepare_model_for_quantization(model)
+ model = model.eval().to(device)
+
+ self.enable_hpu_graph = (
+ os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+ )
+ self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
+
+ if model.config.model_type not in [
+ "gpt_bigcode"
+ ]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+ model = remove_kv_cache_from_output(model)
+
+ if self.enable_hpu_graph:
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+ model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+ else:
+ if LAZY_MODE == 0:
+                # "keep_input_mutations" is reported to be safe for inference
+ dbg_trace("TORCH COMPILE", "Torch compiling of model")
+ model.model = torch.compile(
+ model.model,
+ backend="hpu_backend",
+ options={"keep_input_mutations": True},
+ )
+
+ model = hq_env.setup_quantization(model)
+
+ if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
+ raise ValueError(f"Model type {model.config.model_type} is not supported!")
+
+ if tokenizer.pad_token_id is None:
+ if model.config.pad_token_id is not None:
+ tokenizer.pad_token_id = model.config.pad_token_id
+ elif model.config.eos_token_id is not None:
+ if isinstance(model.config.eos_token_id, int):
+ tokenizer.pad_token_id = model.config.eos_token_id
+ elif isinstance(model.config.eos_token_id, list):
+ tokenizer.pad_token_id = model.config.eos_token_id[0]
+ else:
+ raise ValueError(
+ f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id"
+ )
+ elif tokenizer.eos_token_id is not None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ else:
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+ self.kwargs = {
+ "use_cache": True,
+ "return_dict": True,
+ }
+
+ if model.config.model_type in [
+ "llama",
+ "mistral",
+ "starcoder2",
+ "qwen2",
+ "falcon",
+ "gpt_bigcode",
+ ]:
+ if model.config.model_type not in ["falcon", "gpt_bigcode"]:
+ self.kwargs["attn_softmax_bf16"] = True
+
+ if model.config.model_type not in ["gpt_bigcode"]:
+ self.kwargs["trim_logits"] = True
+
+ if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
+ self.kwargs["use_flash_attention"] = True
+ if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true":
+ self.kwargs["flash_attention_recompute"] = True
+
+ self.speculate = get_speculate()
+
+ super(CausalLM, self).__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ )
+
+ # Create profiler
+ ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
+ record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
+ output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
+ self.profiling_warmup_steps = (
+ int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+ )
+ self.profiling_steps = (
+ int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+ )
+ self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
+ if self.profiling_steps > 0:
+ self.hb_profiler = HabanaProfile(
+ wait=self.profiling_wait_steps,
+ warmup=self.profiling_warmup_steps,
+ active=self.profiling_steps,
+ output_dir=output_dir,
+ record_shapes=record_shapes,
+ )
+ self.hb_profiler.start()
+ else:
+ self.hb_profiler = None
+ self.step = 0
+
+ def get_deepspeed_model(
+ self, model_id: str, dtype: torch.dtype, revision: Optional[str] = None
+ ) -> torch.nn.Module:
+ import deepspeed
+ from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
+
+ world_size, rank, local_rank = initialize_distributed_hpu()
+ model_kwargs = {"revision": revision}
+
+ # Initialize process(es) for DeepSpeed
+ deepspeed.init_distributed(dist_backend="hccl")
+ logger.info(
+ "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
+ world_size, rank, local_rank
+ )
+ )
+ config = AutoConfig.from_pretrained(model_id, **model_kwargs)
+ load_to_meta = model_on_meta(config)
+
+ # Check support for rope scaling
+ if hasattr(config, "rope_scaling"):
+ config.rope_scaling = self.get_rope_scaling()
+ model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
+ if load_to_meta:
+ # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
+ with deepspeed.OnDevice(dtype=dtype, device="meta"):
+ model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
+ else:
+ get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
+ # TODO: revisit placement on CPU when auto-injection is possible
+ with deepspeed.OnDevice(dtype=dtype, device="cpu"):
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, torch_dtype=dtype, **model_kwargs
+ )
+ model = model.eval()
+
+ # Initialize the model
+ ds_inference_kwargs = {"dtype": dtype}
+ ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
+ ds_inference_kwargs["enable_cuda_graph"] = False
+
+ if load_to_meta:
+ # model loaded to meta is managed differently
+ checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
+ write_checkpoints_json(model_id, local_rank, checkpoints_json)
+ ds_inference_kwargs["checkpoint"] = checkpoints_json.name
+ model = deepspeed.init_inference(model, **ds_inference_kwargs)
+
+ return model.module
+
+ def get_rope_scaling(self) -> Optional[Dict]:
+ rope_scaling = os.getenv("ROPE_SCALING", None)
+ if rope_scaling is None:
+ return None
+
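+        # Example: with ROPE_SCALING=linear and ROPE_FACTOR=2.0 set in the
+        # environment, this returns {"type": "linear", "factor": 2.0}.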
+ rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
+ return {"type": rope_scaling, "factor": float(rope_factor)}
+
+ @property
+ def batch_type(self) -> Type[CausalLMBatch]:
+ return CausalLMBatch
+
+ def decode(self, generated_ids: List[int]) -> str:
+ return self.tokenizer.decode(
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ def decode_token(
+ self,
+ all_input_ids: List[int],
+ prefix_offset: int = 0,
+ read_offset: int = 0,
+ ) -> Tuple[str, int, int]:
+ if is_tokenizer_transparent(self.tokenizer):
+ new_text = self.tokenizer.decode(
+ all_input_ids[read_offset:], skip_special_tokens=False
+ )
+ return new_text, read_offset, len(all_input_ids)
+ else:
+ return super().decode_token(all_input_ids, prefix_offset, read_offset)
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ token_idx,
+ past_key_values: Optional[List[Tuple]] = None,
+ bypass_hpu_graph: Optional[bool] = None,
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+ # Model Forward
+ kwargs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past_key_values,
+ "token_idx": token_idx,
+ }
+
+        # Optimum Habana only supports the "lazy_mode" kwarg for llama-type models
+ if self.model.config.model_type == "llama":
+ kwargs["lazy_mode"] = LAZY_MODE == 1
+
+ if self.has_position_ids:
+ kwargs["position_ids"] = position_ids
+
+ if bypass_hpu_graph is not None:
+ kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
+
+ kwargs.update(self.kwargs)
+
+ if past_key_values is not None and self.model.config.model_type not in [
+ "gpt_bigcode"
+ ]:
+ return self.model.forward(**kwargs)
+ else:
+ outputs = self.model.forward(**kwargs)
+ return outputs.logits, outputs.past_key_values
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batches: List[CausalLMBatch]
+ ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
+ start = time.time_ns()
+ # Results
+ generations: List[Generation] = []
+ prev_batches = []
+ requests_to_generate = []
+ # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
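+        #   1. collect next-token ids from the previously scheduled forward pass,
+        #   2. recombine the batches and schedule the next forward pass,
+        #   3. decode and return the generations produced by the previous pass.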
+ # Stage 1. Collect next token ids of any previously started generations
+ for batch_id, batch in enumerate(batches):
+ if batch.logits is not None:
+ logits = batch.logits
+ past = batch.past
+ prefill = batch.past_key_values is None
+ if prefill:
+ # no right padding for prefill
+ token_idx_scalar = batch.attention_mask.shape[-1] - 1
+ token_idx = torch.tensor(token_idx_scalar).to(self.device)
+ else:
+ token_idx_scalar = (
+ batch.attention_mask.shape[-1] - batch.right_padding
+ )
+ token_idx = torch.tensor(token_idx_scalar).to(self.device)
+
+ # Select next token
+ input_length = batch.input_length
+ if logits.shape[-2] > 1:
+ next_token_ids, next_token_logprobs, logprobs, _, _ = (
+ batch.next_token_chooser(
+ batch.input_ids,
+ logits[:, input_length - 1 : input_length, :].squeeze(-2),
+ self.speculate,
+ )
+ )
+ else:
+ next_token_ids, next_token_logprobs, logprobs, _, _ = (
+ batch.next_token_chooser(
+ batch.input_ids, logits.squeeze(-2), self.speculate
+ )
+ )
+ # Speculation is not active for causal
+ accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens,
+ batch.top_n_tokens_tensor,
+ logprobs,
+ accepted_ids,
+ )
+
+ prev_batches.append(
+ {
+ "next_token_ids": next_token_ids,
+ "next_token_logprobs": next_token_logprobs,
+ }
+ )
+
+ for req_idx, req in enumerate(batch.requests):
+ requests_to_generate.append(
+ {
+ "req": req,
+ "prev_req_idx": req.idx,
+ "batch_id": batch_id,
+ "seed": batch.next_token_chooser.seeds[req_idx],
+ "do_sample": batch.next_token_chooser.do_sample[req_idx],
+ "top_n_tokens": batch.top_n_tokens[req_idx],
+ "top_token_ids": batch_top_token_ids[req_idx],
+ "top_token_logprobs": batch_top_token_logprobs[req_idx],
+ "grammar_state": batch.next_token_chooser.fsm_grammar_states[
+ req.idx
+ ],
+ }
+ )
+
+ htorch.core.mark_step()
+
+ # Add new token into input_ids
+ batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
+
+ # Update attention_mask as we added a new token to input_ids
+ batch.attention_mask.index_fill_(1, token_idx, 1)
+
+ # Adjust lengths
+ batch.input_length += 1
+
+ # Update position_ids
+ if prefill:
+ batch.position_ids = (
+ torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+ )
+ else:
+ batch.position_ids += 1
+ # Update past key values
+ if prefill or self.model.config.model_type in ["gpt_bigcode"]:
+ batch.past_key_values = past
+
+ htorch.core.mark_step()
+
+ # Stage 2. Prepare new batch for speculative scheduling
+ if len(batches) > 1:
+ batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
+ else:
+ batch = batches[0]
+
+ prefill = batch.past_key_values is None
+
+ # Check if we need to do any bookkeeping first
+ if not prefill:
+ batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
+
+ scenario = "PREFILL" if prefill else "GENERATE"
+ if (
+ self.enable_hpu_graph
+ and self.limit_hpu_graph
+ and round_up_batch(batch.batch_size) != self.prev_bs
+ ):
+ self.model.clear_cache()
+ self.prev_bs = round_up_batch(batch.batch_size)
+ dbg_trace(
+ scenario,
+ f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
+ )
+ assert batch.right_padding > 0, "No more room for next token!"
+
+ # Execute batch
+ if prefill:
+ # no right padding for prefill
+ token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+ batch.logits, batch.past = self.forward(
+ batch.input_ids,
+ batch.attention_mask,
+ batch.position_ids,
+ token_idx,
+ batch.past_key_values,
+ bypass_hpu_graph=(
+ prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+ ),
+ )
+ elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
+ # Don't schedule next forward if max_new_tokens for all requests equals 1
+ # - we've already generated the first and only needed token in the prefill phase
+ pass
+ else:
+ token_idx = torch.tensor(
+ batch.attention_mask.shape[-1] - batch.right_padding
+ ).to(self.device)
+ input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
+ logits = self.forward(
+ input_ids,
+ batch.attention_mask,
+ batch.position_ids,
+ token_idx,
+ batch.past_key_values,
+ bypass_hpu_graph=(
+ prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+ ),
+ )
+ if self.model.config.model_type in ["gpt_bigcode"]:
+ batch.logits, batch.past = logits
+ else:
+ batch.logits = logits
+
+ htorch.core.mark_step()
+
+ start_decode = time.time_ns()
+
+ # Stage 3. Finish and return previous generations
+ stopped = len(requests_to_generate) > 0
+ for prev_batch in prev_batches:
+ prev_batch["next_token_logprobs"] = prev_batch[
+ "next_token_logprobs"
+ ].tolist()
+ prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
+ htorch.core.mark_step()
+
+ for req_data in requests_to_generate:
+ req = req_data["req"]
+ i = req_data["prev_req_idx"]
+ prev_batch_id = req_data["batch_id"]
+ assert len(prev_batches) > prev_batch_id
+ next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
+ next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
+
+ request = req.data
+ input_length = req.input_length
+ prefix_offset = req.prefix_offset
+ read_offset = req.read_offset
+ do_sample = req_data["do_sample"]
+ seed = req_data["seed"]
+ stopping_criteria = req.stopping_criteria
+ all_input_ids = req.all_input_ids
+ next_token_id = next_token_ids_cpu[i]
+ next_token_logprob = next_token_logprobs[i]
+ top_n_tokens = req_data["top_n_tokens"]
+ top_token_ids = req_data["top_token_ids"]
+ top_token_logprobs = req_data["top_token_logprobs"]
+ grammar_state = req_data["grammar_state"]
+
+ # Append next token to all tokens
+ all_input_ids[input_length] = next_token_id
+ new_input_length = input_length + 1
+
+ # Generated token
+ if (
+ is_tokenizer_transparent(self.tokenizer)
+ and len(stopping_criteria.stop_sequence_criterias) == 0
+ ):
+ next_token_text = ""
+ else:
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(
+ next_token_id,
+ next_token_text,
+ )
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ if is_tokenizer_transparent(self.tokenizer):
+ output_text = None
+ else:
+ output_text = self.decode(
+ all_input_ids[
+ new_input_length
+ - stopping_criteria.current_tokens : new_input_length,
+ 0,
+ ]
+ )
+ generated_text = GeneratedText(
+ output_text,
+ stopping_criteria.current_tokens,
+ reason,
+ seed if do_sample else None,
+ )
+ else:
+ generated_text = None
+
+ # Prefill
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ # Remove generated token to only have prefill and add nan for first prompt token
+ prefill_logprobs = [float("nan")] + next_token_logprobs
+ prefill_token_ids = all_input_ids[0 : new_input_length - 1]
+ prefill_texts = self.tokenizer.batch_decode(
+ prefill_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ prefill_tokens = Tokens(
+ prefill_token_ids,
+ prefill_logprobs,
+ prefill_texts,
+ is_special=[],
+ )
+ else:
+ prefill_tokens = None
+
+ if top_n_tokens > 0:
+ all_top_tokens = []
+ for top_token_ids, top_token_logprobs in zip(
+ top_token_ids, top_token_logprobs
+ ):
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids
+ for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ all_top_tokens.append(top_tokens)
+ top_tokens = all_top_tokens
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ Tokens(
+ [next_token_id],
+ [next_token_logprob],
+ [next_token_text],
+ [next_token_id in self.all_special_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ batch.next_token_chooser = (
+ batch.next_token_chooser.advance_grammar_single_with_past_state(
+ req.idx, next_token_id, grammar_state
+ )
+ )
+
+ req.all_input_ids = all_input_ids
+ req.input_length = new_input_length
+ req.prefix_offset = prefix_offset
+ req.read_offset = read_offset
+
+ htorch.core.mark_step()
+ self.step = self.step + 1
+ if self.hb_profiler is not None:
+ if (
+ self.step
+ > self.profiling_wait_steps
+ + self.profiling_warmup_steps
+ + self.profiling_steps
+ ):
+ self.hb_profiler.stop()
+ else:
+ self.hb_profiler.step()
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch if not stopped else None, (forward_ns, decode_ns)
+
+ def generate_warmup_batch(self, request, seq_len, batch_size):
+ batch = copy.deepcopy(request.batch)
+ for req in batch.requests:
+ req.truncate = seq_len
+
+ for i in range(len(batch.requests) - batch_size):
+ batch.requests.pop()
+
+ return self.batch_type.from_pb(batch, self.tokenizer, self.dtype, self.device)
+
+ def warmup(
+ self, request: generate_pb2.WarmupRequest
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+ assert (
+ MAX_BATCH_SIZE is not None
+ ), "MAX_BATCH_SIZE is not set, it should be set in the launcher"
+ MAX_BATCH_TOTAL_TOKENS = MAX_BATCH_SIZE * request.max_total_tokens
+ logger.info(f"MAX_BATCH_SIZE: {MAX_BATCH_SIZE}")
+ logger.info(f"MAX_BATCH_TOTAL_TOKENS: {MAX_BATCH_TOTAL_TOKENS}")
+ MAX_TOTAL_TOKENS = request.max_total_tokens
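+        # Example (hypothetical values): with MAX_BATCH_SIZE = 32 and
+        # request.max_total_tokens = 2048, MAX_BATCH_TOTAL_TOKENS is 65536, so the
+        # decode warmup below covers batch sizes up to floor(65536 / 2048) = 32.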
+
+ batch = self.batch_type.from_pb(
+ request.batch, self.tokenizer, self.dtype, self.device
+ )
+ max_prefill_batch_size = batch.input_ids.shape[0]
+ try:
+ # max prefill batch size warmup
+ _, prefill_batch, _ = self.generate_token([batch])
+ except Exception:
+ raise RuntimeError(
+ f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+ f"You need to decrease `--max-batch-prefill-tokens`"
+ )
+
+ del prefill_batch
+
+ # Warmup prefill batch_size
+ max_input_tokens = request.max_input_tokens
+ max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
+ prefill_batch_size_list = [
+ BATCH_SIZE_EXPONENT_BASE**exp
+ for exp in range(
+ 0,
+ max_exp + 1,
+ )
+ ]
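+        # Example (assuming BATCH_SIZE_EXPONENT_BASE = 2): with
+        # max_prefill_batch_size = 8, max_exp = 3 and the warmup batch sizes are
+        # [1, 2, 4, 8].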
+ prefill_seqlen_list = [
+ seq
+ for seq in range(
+ PAD_SEQUENCE_TO_MULTIPLE_OF,
+ max_input_tokens,
+ PAD_SEQUENCE_TO_MULTIPLE_OF,
+ )
+ ]
+ prefill_seqlen_list.append(max_input_tokens)
+ prefill_batch_size_list.sort(reverse=True)
+ prefill_seqlen_list.sort(reverse=True)
+ try:
+ for batch_size in prefill_batch_size_list:
+ for seq_len in prefill_seqlen_list:
+ batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
+ _, prefill_batch, _ = self.generate_token([batch])
+ except Exception:
+ prefill_batch_size_list.sort()
+ prefill_seqlen_list.sort()
+ raise RuntimeError(
+ f"Not enough memory to run following prefill batch_size."
+ f"Prefill batch size list:{prefill_batch_size_list}"
+ f"Prefill sequence length list:{prefill_seqlen_list}"
+ f"You need to decrease `--max-batch-prefill-tokens`"
+ )
+ prefill_seqlen_list.sort()
+ prefill_batch_size_list.sort()
+ mem_stats = get_hpu_memory_stats(self.device)
+ logger.info(
+ f"\nFollowing prefill warmup successfully.\n"
+ f"Prefill batch size list:{prefill_batch_size_list}\n"
+ f"Prefill sequence length list:{prefill_seqlen_list}\n"
+ f"Memory stats: {mem_stats} "
+ )
+
+ max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+ max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
+ decode_batch_size_list = [
+ BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
+ ]
+ decode_batch_size_list.sort(reverse=True)
+
+ try:
+ for batch_size in decode_batch_size_list:
+ batches = []
+ iters = math.floor(batch_size / max_prefill_batch_size)
+ for i in range(iters):
+ batch = self.generate_warmup_batch(
+ request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size
+ )
+ _, prefill_batch, _ = self.generate_token([batch])
+ batches.append(prefill_batch)
+
+ if batch_size % max_prefill_batch_size != 0:
+ batch = self.generate_warmup_batch(
+ request,
+ PAD_SEQUENCE_TO_MULTIPLE_OF - 1,
+ batch_size % max_prefill_batch_size,
+ )
+ _, prefill_batch, _ = self.generate_token([batch])
+ batches.append(prefill_batch)
+
+ _, decode_batch, _ = self.generate_token(batches)
+ _, decode_batch, _ = self.generate_token([decode_batch])
+ del decode_batch
+ batches.clear()
+
+ except Exception:
+ raise RuntimeError(
+ f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
+ f"You need to decrease `--max-batch-total-tokens`"
+ )
+
+ decode_batch_size_list.sort()
+ max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
+ mem_stats = get_hpu_memory_stats(self.device)
+ logger.info(
+ f"\nFollowing decode warmup successfully.\n"
+ f"Decode batch size list:{decode_batch_size_list}\n"
+ f"Memory stats: {mem_stats} "
+ )
+
+ max_total_tokens = MAX_TOTAL_TOKENS
+
+ return max_supported_total_tokens, max_input_tokens, max_total_tokens
diff --git a/.devcontainer/devcontainer.json b/backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py
similarity index 100%
rename from .devcontainer/devcontainer.json
rename to backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py
new file mode 100644
index 000000000..84835ab89
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -0,0 +1,923 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BLOOM model."""
+
+import math
+import os
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.distributed
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import LayerNorm
+from torch.nn import functional as F
+
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+)
+from transformers import BloomConfig, PreTrainedModel
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ SpeculativeHead,
+)
+
+CUSTOM_KERNELS_ENABLED = False
+if (
+ torch.cuda.is_available()
+ and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
+ try:
+ from custom_kernels import fused_bloom_attention_cuda
+
+ CUSTOM_KERNELS_ENABLED = True
+ except ImportError:
+ pass
+
+_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
+_CONFIG_FOR_DOC = "BloomConfig"
+
+BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bigscience/bigscience-small-testing",
+ "bigscience/bloom-560m",
+ "bigscience/bloom-1b1",
+ "bigscience/bloom-1b7",
+ "bigscience/bloom-3b",
+ "bigscience/bloom-7b1",
+ "bigscience/bloom",
+]
+
+
+def _make_causal_mask(
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+) -> torch.BoolTensor:
+ """
+ Make causal mask used for self-attention.
+ """
+ batch_size, target_length = input_ids_shape
+ mask = torch.ones(
+ (target_length, target_length + past_key_values_length),
+ dtype=torch.bool,
+ device=device,
+ )
+ mask = mask.triu(1 + past_key_values_length)
+
+ expanded_mask = mask.unsqueeze(0).expand(
+ batch_size, target_length, target_length + past_key_values_length
+ )
+ return expanded_mask
+
+
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+ """
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
+ """
+ batch_size, src_length = mask.shape
+ tgt_length = tgt_length if tgt_length is not None else src_length
+
+ expanded_mask = ~(mask[:, None, :].to(torch.bool))
+ return expanded_mask.expand(batch_size, tgt_length, src_length)
+
+
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
+ """
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+ `softmax(l+a) = softmax(l)`. Based on
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+ num_heads (`int`, *required*):
+ number of heads
+
+ Returns:
+ Tensor shaped (batch_size * num_heads, 1, max_seq_len)
+ """
+ batch_size, seq_length = attention_mask.shape
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+ base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
+ device=attention_mask.device,
+ dtype=torch.float32,
+ )
+ powers = torch.arange(
+ 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
+ )
+ slopes = torch.pow(base, powers)
+
+ if closest_power_of_2 != num_heads:
+ extra_base = torch.tensor(
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
+ device=attention_mask.device,
+ dtype=torch.float32,
+ )
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+ extra_powers = torch.arange(
+ 1,
+ 1 + 2 * num_remaining_heads,
+ 2,
+ device=attention_mask.device,
+ dtype=torch.int32,
+ )
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+ # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+ # => the query_length dimension will then be broadcasted correctly
+ # This is more or less identical to T5's relative position bias:
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+ alibi = slopes[..., None] * arange_tensor
+ return alibi
+
+
+# @torch.jit.script
+def dropout_add(
+ x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
+) -> torch.Tensor:
+ """
+ Dropout add function
+
+ Args:
+ x (`torch.tensor`, *required*):
+ input tensor
+ residual (`torch.tensor`, *required*):
+ residual tensor
+ prob (`float`, *required*):
+ dropout probability
+ training (`bool`, *required*):
+ training mode
+ """
+ out = F.dropout(x, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+# @torch.jit.script  # disabled: scripting this function misbehaves for unknown reasons.
+def _split_heads(
+ fused_qkv: torch.Tensor, num_heads: int, head_dim: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+ storage as `fused_qkv`
+
+ Args:
+ fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+ Returns:
+ query: [batch_size * num_heads, seq_length, head_dim]
+ key: [batch_size * num_heads, head_dim, seq_length]
+ value: [batch_size * num_heads, seq_length, head_dim]
+ """
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+ fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
+ query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
+
+ query_layer = query_layer.transpose(1, 2).reshape(
+ batch_size * num_heads, seq_length, head_dim
+ )
+ key_layer = key_layer.permute(0, 2, 3, 1).reshape(
+ batch_size * num_heads, head_dim, seq_length
+ )
+ value_layer = value_layer.transpose(1, 2).reshape(
+ batch_size * num_heads, seq_length, head_dim
+ )
+
+ return query_layer, key_layer, value_layer
+
+
+# @torch.jit.script
+def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
+ """
+ Merge heads together over the last dimension
+
+ Args:
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
+
+ Returns:
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+ """
+ # What we want to achieve is:
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+ batch_size_and_num_heads, seq_length, _ = x.shape
+ batch_size = batch_size_and_num_heads // num_heads
+
+ # First view to decompose the batch size
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+ x = x.view(batch_size, num_heads, seq_length, head_dim)
+
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
+ x = x.permute(0, 2, 1, 3)
+
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+ return x.reshape(batch_size, seq_length, num_heads * head_dim)
+
+
+class BloomAttention(nn.Module):
+ def __init__(self, prefix, config: BloomConfig, weights):
+ super().__init__()
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+
+ self.process_group = weights.process_group
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.n_head
+ self.head_dim = self.hidden_size // self.num_heads
+ self.split_size = self.hidden_size
+ self.hidden_dropout = config.hidden_dropout
+
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ # Layer-wise attention scaling
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+ self.beta = 1.0
+
+ process_group = weights.process_group
+ if self.num_heads % process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {process_group.size()}"
+ )
+ self.num_heads = self.num_heads // process_group.size()
+ self.query_key_value = TensorParallelColumnLinear.load(
+ config=config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ bias=True,
+ )
+ self.dense = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
+ )
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+ @staticmethod
+ def compute_attention(
+ fused_qkv: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ head_mask: Optional[torch.Tensor],
+ beta: float,
+ inv_norm_factor: float,
+ num_heads: int,
+ use_cache: bool,
+ ):
+ batch_size, q_length, three_times_hidden_size = fused_qkv.shape
+ head_dim = three_times_hidden_size // (3 * num_heads)
+
+ ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = _split_heads(
+ fused_qkv, num_heads=num_heads, head_dim=head_dim
+ )
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ # concatenate along seq_length dimension:
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
+ past_key = past_key.view(-1, *past_key.shape[-2:])
+ key_layer = torch.cat((past_key, key_layer), dim=2)
+ past_value = past_value.view(-1, *past_value.shape[-2:])
+ value_layer = torch.cat((past_value, value_layer), dim=1)
+
+ _, _, kv_length = key_layer.shape
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+ ###
+
+ # [batch_size * num_heads, q_length, kv_length]
+ # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+ attention_scores = alibi.baddbmm(
+ batch1=query_layer,
+ batch2=key_layer,
+ beta=beta,
+ alpha=inv_norm_factor,
+ )
+
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+ input_dtype = attention_scores.dtype
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+ if input_dtype == torch.float16:
+ attention_scores = attention_scores.to(torch.float)
+ # mask out padded positions with the minimum representable value for this dtype
+ attn_weights = attention_scores.masked_fill_(
+ attention_mask, torch.finfo(attention_scores.dtype).min
+ )
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+ input_dtype
+ )
+
+ # # [batch_size, num_heads, q_length, kv_length]
+ # attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
+
+ # change view [batch_size, num_heads, q_length, head_dim]
+ context_layer = _merge_heads(
+ context_layer, num_heads=num_heads, head_dim=head_dim
+ )
+
+ return context_layer, present, attention_probs
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ fused_qkv = self.query_key_value(
+ hidden_states
+ ) # [batch_size, seq_length, 3 x hidden_size]
+ batch_size, q_length, _ = fused_qkv.shape
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ layer_past = (
+ past_key.view(-1, *past_key.shape[-2:]),
+ past_value.view(-1, *past_value.shape[-2:]),
+ )
+
+ if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
+ assert self.training is False, "Only forward pass was implemented"
+ assert (
+ attention_mask.shape[-1] < 4096
+ ), "Custom kernel support only up to 4096 tokens"
+ (
+ context_layer,
+ present,
+ attention_probs,
+ ) = fused_bloom_attention_cuda.forward(
+ fused_qkv,
+ layer_past,
+ alibi,
+ attention_mask,
+ head_mask,
+ self.beta,
+ self.inv_norm_factor,
+ self.num_heads,
+ use_cache,
+ )
+ else:
+ context_layer, present, attention_probs = self.compute_attention(
+ fused_qkv=fused_qkv,
+ layer_past=layer_past,
+ alibi=alibi,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ beta=self.beta,
+ inv_norm_factor=self.inv_norm_factor,
+ num_heads=self.num_heads,
+ use_cache=use_cache,
+ )
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+ output_tensor += residual
+
+ outputs = (output_tensor, present)
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs
+
+
+class BloomMLP(nn.Module):
+ def __init__(self, prefix, config: BloomConfig, weights):
+ super().__init__()
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+ self.dense_h_to_4h = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
+ )
+ self.dense_4h_to_h = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
+ )
+ self.gelu_impl = torch.nn.GELU(approximate="tanh")
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self, hidden_states: torch.Tensor, residual: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + F.linear(
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[
+ :, int(i * slices) : int((i + 1) * slices)
+ ],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+
+ # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+ intermediate_output += residual
+
+ return intermediate_output
+
+
+class BloomBlock(nn.Module):
+ def __init__(self, layer_id: int, config: BloomConfig, weights):
+ super().__init__()
+
+ prefix = f"h.{layer_id}"
+ self.input_layernorm = LayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ self.num_heads = config.n_head
+ self.self_attention = BloomAttention(
+ prefix=f"{prefix}.self_attention", config=config, weights=weights
+ )
+ self.post_attention_layernorm = LayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.apply_residual_connection_post_layernorm = (
+ config.apply_residual_connection_post_layernorm
+ )
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class BloomPreTrainedModel(PreTrainedModel):
+ config_class = BloomConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["BloomBlock"]
+
+ @staticmethod
+ def _convert_to_standard_cache(
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+ num_heads, ...]))
+ """
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+ num_heads = batch_size_times_num_heads // batch_size
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+ return tuple(
+ (
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+ )
+ for layer_past in past_key_value
+ )
+
+ @staticmethod
+ def _convert_to_bloom_cache(
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+ """
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+ batch_size_times_num_heads = batch_size * num_heads
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+ return tuple(
+ (
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+ )
+ for layer_past in past_key_value
+ )
+
+
+class BloomModel(BloomPreTrainedModel):
+ def __init__(self, config: BloomConfig, weights):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.n_head
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+
+ self.word_embeddings = TensorParallelEmbedding(
+ prefix="word_embeddings", weights=weights
+ )
+
+ self.word_embeddings_layernorm = LayerNorm.load(
+ prefix="word_embeddings_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ # Transformer blocks
+ self.h = nn.ModuleList(
+ [
+ BloomBlock(layer_id=layer_id, config=config, weights=weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ # Final Layer Norm
+ self.ln_f = LayerNorm.load(
+ prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ def _prepare_attn_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_shape: Tuple[int, int],
+ past_key_values_length: int,
+ ) -> torch.BoolTensor:
+ # create causal mask
+ # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
+ combined_attention_mask = None
+ device = attention_mask.device
+ _, src_length = input_shape
+
+ if src_length > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ device=device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask | combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ if past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[-1]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), device=hidden_states.device
+ )
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ alibi = build_alibi_tensor(attention_mask, self.num_heads)
+
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
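+ # Under tensor parallelism each rank only holds `num_heads // tp_world_size`
+ # attention heads, so the ALiBi bias is sliced to this rank's contiguous block of
+ # heads and the causal mask is repeated once per local head to match the
+ # (batch * local_heads, 1, seq_len) layout used below.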
+ if hasattr(self, "tp_rank"):
+ assert self.num_heads % self.tp_world_size == 0
+ block_size = self.num_heads // self.tp_world_size
+ alibi = alibi[
+ :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
+ ]
+ alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
+ causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
+ else:
+ alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
+ causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
+
+ alibi = alibi.to(hidden_states.dtype)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (
+ outputs[2 if use_cache else 1],
+ )
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class BloomForCausalLM(BloomPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+ self.transformer = BloomModel(config, weights)
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="word_embeddings",
+ weights=weights,
+ )
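+ # BLOOM ties its output projection to the input embeddings, which is presumably
+ # why the speculative head is loaded from the `word_embeddings` weights here.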
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if past_key_values:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+ if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+ past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect whether users were explicitly passing `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ logits, speculative_logits = self.lm_head(hidden_states)
+ loss = None
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return (
+ CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ ),
+ speculative_logits,
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py
new file mode 100644
index 000000000..ab824da5b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py
@@ -0,0 +1,817 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import (
+ _create_4d_causal_attention_mask,
+ _prepare_4d_attention_mask,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPooling,
+)
+from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+
+from text_generation_server.layers import (
+ TensorParallelEmbedding,
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+
+class CLIPVisionEmbeddings(nn.Module):
+ def __init__(self, prefix, config: CLIPVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ # TODO: should we TP this?
+ self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+ self.register_buffer(
+ "position_ids",
+ torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+ pixel_values.to(dtype=target_dtype)
+ ) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
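+ # [batch_size, embed_dim, grid, grid] -> [batch_size, num_patches, embed_dim]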
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class CLIPTextEmbeddings(nn.Module):
+ def __init__(self, config: CLIPTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(
+ config.max_position_embeddings, embed_dim
+ )
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids",
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = (
+ input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ )
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+class CLIPAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_size = self.embed_dim // self.num_heads
+ if self.head_size * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+ self.scale = self.head_size**-0.5
+ self.dropout = config.attention_dropout
+
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.out_proj",
+ weights=weights,
+ bias=True,
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_size)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ ]
+ * 3,
+ dim=2,
+ )
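+ # The q/k/v projections are fused into one column-parallel linear, so each shard
+ # holds `num_heads / tp` heads and the output splits evenly into three equal chunks.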
+ query_states = query_states * self.scale
+ key_states = self._shape(key_states, -1, bsz)
+ value_states = self._shape(value_states, -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_size)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + causal_attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ attn_probs = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None
+
+
+class CLIPMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class CLIPEncoderLayer(nn.Module):
+ def __init__(self, prefix, config: CLIPConfig, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = CLIPAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+ )
+ self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class CLIPPreTrainedModel(nn.Module):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = CLIPConfig
+ base_model_prefix = "clip"
+ supports_gradient_checkpointing = True
+
+
+CLIP_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CLIP_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+"""
+
+CLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+"""
+
+CLIP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+"""
+
+
+class CLIPEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
+ [`CLIPEncoderLayer`].
+
+ Args:
+ config: CLIPConfig
+ """
+
+ def __init__(self, prefix, config: CLIPConfig, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ CLIPEncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ """
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ )
+
+ return hidden_states
+
+
+class CLIPTextTransformer(nn.Module):
+ def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = CLIPTextEmbeddings(config)
+ # Initialize weights and apply final processing with `self.post_init()`
+ self.encoder = CLIPEncoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ # For `pooled_output` computation
+ self.eos_token_id = config.eos_token_id
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Returns:
+
+ """
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ if self.eos_token_id == 2:
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
+ # ------------------------------------------------------------
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ last_hidden_state[
+ torch.arange(
+ last_hidden_state.shape[0], device=last_hidden_state.device
+ ),
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
+ dim=-1
+ ),
+ ]
+ else:
+ # The config gets an updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
+ last_hidden_state[
+ torch.arange(
+ last_hidden_state.shape[0], device=last_hidden_state.device
+ ),
+ # We need to get the first position of the `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
+ (
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device)
+ == self.eos_token_id
+ )
+ .int()
+ .argmax(dim=-1),
+ ]
+
+ return last_hidden_state
+
+
+class CLIPTextModel(CLIPPreTrainedModel):
+ config_class = CLIPTextConfig
+
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+ def __init__(self, prefix, config: CLIPTextConfig):
+ super().__init__(config)
+ self.text_model = CLIPTextTransformer(prefix, config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+
+
+class CLIPVisionTransformer(nn.Module):
+ def __init__(self, prefix, config: CLIPVisionConfig, weights):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = CLIPVisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.pre_layrnorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+ )
+ self.encoder = CLIPEncoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ # self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Returns:
+
+ """
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ )
+ last_hidden_state = encoder_outputs
+ # pooled_output = last_hidden_state[:, 0, :]
+ # pooled_output = self.post_layernorm(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ # pooler_output=pooled_output,
+ # hidden_states=encoder_outputs,
+ )
+
+
+class CLIPVisionModel(CLIPPreTrainedModel):
+ config_class = CLIPVisionConfig
+ main_input_name = "pixel_values"
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPVisionConfig):
+ super().__init__(config)
+ self.vision_model = CLIPVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPVisionModel
+
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
+ ```"""
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ )
+
+
+class CLIPModel(nn.Module):
+ def __init__(self, prefix, config: CLIPConfig, weights):
+ super().__init__()
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = CLIPTextTransformer(text_config)
+ self.vision_model = CLIPVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Linear(
+ self.vision_embed_dim, self.projection_dim, bias=False
+ )
+ self.text_projection = nn.Linear(
+ self.text_embed_dim, self.projection_dim, bias=False
+ )
+ self.logit_scale = nn.Parameter(
+ torch.tensor(self.config.logit_scale_init_value)
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPModel
+
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+ >>> text_features = model.get_text_features(**inputs)
+ ```"""
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+
+ pooled_output = text_outputs[1]
+ text_features = self.text_projection(pooled_output)
+
+ return text_features
+
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPModel
+
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ )
+
+ pooled_output = vision_outputs[1] # pooled_output
+ image_features = self.visual_projection(pooled_output)
+
+ return image_features
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, CLIPModel
+
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+ ... )
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```"""
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ )
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ )
+
+ image_embeds = vision_outputs[1]
+ image_embeds = self.visual_projection(image_embeds)
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.t()
+
+ return logits_per_image, logits_per_text
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
new file mode 100644
index 000000000..3bcc689d2
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -0,0 +1,493 @@
+# coding=utf-8
+# Copyright 2024 Cohere team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+
+
+class CohereRotary(PositionRotaryEmbedding):
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ):
+ # Such control flow may add some overhead.
+ num_tokens = query.shape[0]
+ head_size = query.shape[-1]
+ rope_mode = RotaryPosEmbeddingMode.PAIRWISE
+ sin = torch.repeat_interleave(sin, 2, dim=-1)
+ cos = torch.repeat_interleave(cos, 2, dim=-1)
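+ # The cos/sin tensors hold one angle per rotation pair; repeat_interleave duplicates
+ # each entry so both elements of a pair share the same angle, which is the layout
+ # assumed here for the PAIRWISE rotary kernel.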
+ rotary_dim = cos.shape[-1]
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, head_size)
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, head_size)
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
+
+
+class CohereLayerNorm(nn.Module):
+ def __init__(self, prefix, weights, eps):
+ super().__init__()
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+ self.weight = nn.Parameter(weight)
+ # Fake weights
+ self.ones = weight.new_ones(weight.shape[1])
+ self.eps = eps
+
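+ # This LayerNorm is applied per attention head: the sharded weight has shape
+ # (local_heads, head_size), and forward normalizes each head's slice in float32
+ # before casting back to the input dtype.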
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(
+ -1, self.weight.shape[0], self.weight.shape[1]
+ )
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ mean = hidden_states.mean(-1, keepdim=True)
+ hidden_states_minus_mean = hidden_states - mean
+ variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+ hidden_states = self.weight.to(torch.float32) * hidden_states
+ hidden_states = hidden_states.view(-1, self.weight.shape[1])
+ return hidden_states.to(input_dtype)
+
+
+def load_attention(config, prefix, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+
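+# _load_gqa loads the concatenated q/k/v column-parallel weight for grouped-query
+# attention and checks that each shard ends up with
+# (local_q_heads + 2 * local_kv_heads) * head_size output rows.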
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ if config.attention_bias:
+ w = [
+ weights.get_sharded(f"{p}.bias", dim=0)
+ for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+ ]
+ bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+ else:
+ bias = None
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=bias))
+
+
+class FlashCohereAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = CohereRotary.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.use_qk_norm = config.use_qk_norm
+ if self.use_qk_norm:
+ self.q_norm = CohereLayerNorm(
+ prefix=f"{prefix}.q_norm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ self.k_norm = CohereLayerNorm(
+ prefix=f"{prefix}.k_norm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ else:
+ self.q_norm = None
+ self.k_norm = None
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, key, value = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ self.head_size * self.num_key_value_heads,
+ self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+
+ if self.use_qk_norm:
+ query = query.reshape(-1, self.head_size)
+ key = key.reshape(-1, self.head_size)
+ query = self.q_norm(query.contiguous())
+ key = self.k_norm(key.contiguous())
+
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_key_value_heads, self.head_size)
+ value = value.view(-1, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, key, cos, sin)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
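+ # Prefill runs dense attention over the freshly stored keys/values, while decode
+ # reads earlier tokens back from the paged HPU KV cache (see the two branches below).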
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), reduce=False
+ )
+
+
+class CohereMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
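+ # The fused gate/up projection is split into two halves below and combined as
+ # down_proj(act(gate) * up), i.e. a gated (SwiGLU-style) MLP.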
+ def forward(self, hidden_states):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
+ )
+
+
+class FlashCohereLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+ self.self_attn = FlashCohereAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ mlp_output = self.mlp(normed_hidden_states)
+ output = attn_output + mlp_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(output, group=self.process_group)
+
+ return output, res
+
+
+class FlashCohereModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.layers = nn.ModuleList(
+ [
+ FlashCohereLayer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: torch.Tensor,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashCohereForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.model = FlashCohereModel(prefix, config, weights)
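+ # Some checkpoints do not ship a standalone `lm_head`; in that case the head is
+ # presumably tied to the token embeddings, so fall back to `embed_tokens` below.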
+ try:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+ except RuntimeError:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.embed_tokens",
+ weights=weights,
+ )
+ self.logit_scale = config.logit_scale
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ logits *= self.logit_scale
+ if speculative_logits is not None:
+ speculative_logits *= self.logit_scale
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
new file mode 100644
index 000000000..15c243c97
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -0,0 +1,745 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple, Any
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ FastLinear,
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from vllm_hpu_extension.ops import DynamicFusedMOE
+
+
+class DbrxAttentionConfig(PretrainedConfig):
+ def __init__(
+ self,
+ attn_pdrop: float = 0,
+ clip_qkv: Optional[float] = None,
+ kv_n_heads: int = 1,
+ rope_theta: float = 10000.0,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.attn_pdrop = attn_pdrop
+ self.clip_qkv = clip_qkv
+ self.kv_n_heads = kv_n_heads
+ self.rope_theta = rope_theta
+
+ for k in ["model_type"]:
+ if k in kwargs:
+ kwargs.pop(k)
+ if len(kwargs) != 0:
+ raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxFFNConfig(PretrainedConfig):
+ def __init__(
+ self,
+ ffn_act_fn: Optional[dict] = None,
+ ffn_hidden_size: int = 3584,
+ moe_num_experts: int = 4,
+ moe_top_k: int = 1,
+ moe_jitter_eps: Optional[float] = None,
+ moe_loss_weight: float = 0.01,
+ moe_normalize_expert_weights: Optional[float] = 1,
+ uniform_expert_assignment: bool = False,
+ **kwargs: Any,
+ ):
+ super().__init__()
+ if ffn_act_fn is None:
+ ffn_act_fn = {"name": "silu"}
+ self.ffn_act_fn = ffn_act_fn
+ self.ffn_hidden_size = ffn_hidden_size
+ self.moe_num_experts = moe_num_experts
+ self.moe_top_k = moe_top_k
+ self.moe_jitter_eps = moe_jitter_eps
+ self.moe_loss_weight = moe_loss_weight
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
+ self.uniform_expert_assignment = uniform_expert_assignment
+
+ if uniform_expert_assignment:
+ raise ValueError("`uniform_expert_assignment = True` is not supported")
+
+ for k in ["model_type"]:
+ if k in kwargs:
+ kwargs.pop(k)
+ if len(kwargs) != 0:
+ raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxConfig(PretrainedConfig):
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "n_heads",
+ "num_hidden_layers": "n_layers",
+ }
+
+ def __init__(
+ self,
+ d_model: int = 2048,
+ n_heads: int = 16,
+ n_layers: int = 24,
+ max_seq_len: int = 2048,
+ vocab_size: int = 32000,
+ resid_pdrop: float = 0.0,
+ emb_pdrop: float = 0.0,
+ attn_config: Optional[DbrxAttentionConfig] = None,
+ ffn_config: Optional[DbrxFFNConfig] = None,
+ use_cache: bool = True,
+ initializer_range: float = 0.02,
+ output_router_logits: bool = False,
+ router_aux_loss_coef: float = 0.05,
+ **kwargs: Any,
+ ):
+ if attn_config is None:
+ self.attn_config = DbrxAttentionConfig()
+ elif isinstance(attn_config, dict):
+ self.attn_config = DbrxAttentionConfig(**attn_config)
+ else:
+ self.attn_config = attn_config
+
+ if ffn_config is None:
+ self.ffn_config = DbrxFFNConfig()
+ elif isinstance(ffn_config, dict):
+ self.ffn_config = DbrxFFNConfig(**ffn_config)
+ else:
+ self.ffn_config = ffn_config
+
+ self.d_model = d_model
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.max_seq_len = max_seq_len
+ self.vocab_size = vocab_size
+ self.resid_pdrop = resid_pdrop
+ self.emb_pdrop = emb_pdrop
+ self.use_cache = use_cache
+ self.initializer_range = initializer_range
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
+
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ @property
+ def num_key_value_heads(self):
+ # We can't use the attribute map, since the number of KV heads is not top-level.
+ return self.attn_config.kv_n_heads
+
+
+def promote_scalar(x: torch.Tensor) -> torch.Tensor:
+ return x.view(1) if len(x.size()) == 0 else x
+
+
+def load_attention(config, prefix, weights):
+ return TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.Wqkv",
+ weights=weights,
+ bias=False,
+ num_heads=config.n_heads,
+ num_key_value_heads=config.attn_config.kv_n_heads,
+ )
+
+
+def _load_experts(config, prefix, weights):
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ assert (
+ config.ffn_config.ffn_hidden_size % world_size == 0
+ ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+ expert_size = config.ffn_config.ffn_hidden_size
+ block_size = expert_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
+ tensor = torch.empty(
+ (config.ffn_config.moe_num_experts * block_size, config.d_model),
+ dtype=weights.dtype,
+ device=weights.device,
+ )
+
+ slice_ = weights._get_slice(f"{prefix}")
+
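+ # The experts are stored as one flat (num_experts * ffn_hidden_size, d_model)
+ # tensor; each iteration copies this rank's `block_size` rows of expert `i` into
+ # the preallocated local buffer.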
+ for i in range(config.ffn_config.moe_num_experts):
+ offset = i * expert_size
+ expert_slice = slice_[start + offset : stop + offset]
+
+ tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+ dtype=weights.dtype
+ ).to(device=weights.device)
+ return tensor
+
+
+def _load_experts_quantized(config, prefix, weights, cls):
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ assert (
+ config.ffn_config.ffn_hidden_size % world_size == 0
+ ), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
+
+ expert_size = config.ffn_config.ffn_hidden_size
+ block_size = expert_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
+ slice_ = weights._get_slice(f"{prefix}")
+
+ experts = []
+ for i in range(config.ffn_config.moe_num_experts):
+ if config.quantize in ["gptq", "awq"]:
+ raise NotImplementedError(
+ "Dbrx does not support gptq/awq quantization yet."
+ )
+ else:
+ offset = i * expert_size
+ expert_slice = (
+ slice_[start + offset : stop + offset]
+ .to(dtype=weights.dtype)
+ .to(device=weights.device)
+ )
+
+ if cls == TensorParallelRowLinear:
+ expert_slice = expert_slice.t().contiguous()
+ linear = get_linear(expert_slice, None)
+ experts.append(cls(linear, weights.process_group))
+ else:
+ linear = get_linear(expert_slice, None)
+ experts.append(cls(linear))
+
+ return experts
+
+
+class DbrxAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.clip_qkv = config.attn_config.clip_qkv
+ self.num_heads = config.n_heads
+ self.hidden_size = config.d_model
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.attn_config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.attn_config.kv_n_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.out_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
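+ # DBRX can clamp the fused QKV activations to [-clip_qkv, clip_qkv],
+ # a stability measure controlled by `attn_config.clip_qkv`.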
+ if self.clip_qkv is not None:
+ qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class DbrxNormAttentionNorm(nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.norm_1 = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
+ )
+ self.self_attn = DbrxAttention(
+ prefix=f"{prefix}.attn", config=config, weights=weights
+ )
+ self.norm_2 = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm_2",
+ weights=weights,
+ eps=1e-5,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.norm_1(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # fused post-attention layer norm
+ normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
+
+ return normed_attn_res_output, attn_res
+
+
+@torch.jit.script
+def select_experts(
+ gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int
+):
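+ """Pick the top-k experts per token and optionally p-norm-normalize their weights.
+
+ Returns flattened (sequence_length * top_k,) tensors of expert indices and weights.
+ """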
+ # all_probs: (sequence_length, n_experts); upcast to float for a numerically stable softmax
+ all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+ # weights, selected_experts: (sequence_length, top-k)
+ weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
+ if moe_normalize_expert_weights:
+ weights = weights / torch.norm(
+ weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True
+ )
+ weights = weights.view(-1)
+ selected_experts = selected_experts.view(-1)
+
+ return selected_experts, weights
+
+
+@torch.jit.script
+def round_up(x: torch.Tensor, value: int):
+ return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
+
+
+class BlockSparseMoE(nn.Module):
+ def __init__(self, prefix, config: DbrxConfig, weights):
+ super().__init__()
+ self.moe_normalize_expert_weights = (
+ config.ffn_config.moe_normalize_expert_weights
+ )
+ self.hidden_dim = config.d_model
+ self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
+ self.num_experts = config.ffn_config.moe_num_experts
+ self.top_k = config.ffn_config.moe_top_k
+
+ act = config.ffn_config.ffn_act_fn["name"]
+ if "gelu" in act:
+ self.act = lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ elif "silu" in act:
+ self.act = torch.nn.functional.silu
+ else:
+ self.act = ACT2FN[act]
+
+ # gating
+ self.gate = FastLinear.load(
+ config, f"{prefix}.router.layer", weights, bias=False
+ )
+
+ # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
+ w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
+ self.num_experts, self.ffn_dim, self.hidden_dim
+ )
+ v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
+ self.num_experts, self.ffn_dim, self.hidden_dim
+ )
+ self.wv1 = torch.cat([w1, v1], dim=1)
+ self.w2 = (
+ _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
+ .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ self.process_group = weights.process_group
+
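+ # Hand each expert's concatenated [w1; v1] and w2 weights to the HPU fused-MoE
+ # op, which handles routing and the expert matmuls in a single fused operation.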
+ self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
+ for i in range(self.num_experts):
+ self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
+ self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(x)
+ out = self.hpu_fused_moe(x, router_logits, self.top_k)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class DenseMoE(nn.Module):
+ def __init__(self, prefix, config: DbrxConfig, weights):
+ super().__init__()
+
+ self.moe_normalize_expert_weights = (
+ config.ffn_config.moe_normalize_expert_weights
+ )
+ self.hidden_dim = config.d_model
+ self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
+ self.num_experts = config.ffn_config.moe_num_experts
+ self.top_k = config.ffn_config.moe_top_k
+
+ act = config.ffn_config.ffn_act_fn["name"]
+ if "gelu" in act:
+ self.act = lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ elif "silu" in act:
+ self.act = torch.nn.functional.silu
+ else:
+ self.act = ACT2FN[act]
+
+ # gating
+ self.gate = FastLinear.load(
+ config, f"{prefix}.router.layer", weights, bias=False
+ )
+
+ self.w1 = _load_experts_quantized(
+ config,
+ prefix=f"{prefix}.experts.mlp.w1",
+ weights=weights,
+ cls=TensorParallelColumnLinear,
+ )
+ self.w2 = _load_experts_quantized(
+ config,
+ prefix=f"{prefix}.experts.mlp.w2",
+ weights=weights,
+ cls=TensorParallelRowLinear,
+ )
+ self.v1 = _load_experts_quantized(
+ config,
+ prefix=f"{prefix}.experts.mlp.v1",
+ weights=weights,
+ cls=TensorParallelColumnLinear,
+ )
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ x: (sequence_length, model_dim)
+ gate_logits: (sequence_length, n_experts)
+ """
+ # optional reshape
+ input_shape = x.shape
+ x = x.view(-1, input_shape[-1])
+
+ # gate_logits: (sequence_length, n_experts)
+ gate_logits = self.gate(x)
+ # weights: (sequence_length, n_experts); upcast to float for a numerically stable softmax
+ weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+ if self.top_k < self.num_experts:
+ _, not_selected_experts = torch.topk(
+ weights,
+ self.num_experts - self.top_k,
+ largest=False,
+ sorted=False,
+ dim=1,
+ )
+ # Mask not selected experts
+ weights.scatter_(1, not_selected_experts, 0)
+
+ # Re-normalize
+ if self.moe_normalize_expert_weights:
+ weights = weights / torch.norm(
+ weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+ )
+ weights = weights.to(x.dtype)
+
+ # Final output tensor
+ out = x.new_zeros(x.shape[0], self.hidden_dim)
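+ # Dense fallback: evaluate every expert on every token and let the zeroed
+ # routing weights cancel the experts that were not selected.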
+ for i in range(self.num_experts):
+ h = self.act(self.w1[i](x)) * self.v1[i](x)
+ h = self.w2[i](h, reduce=False)
+ # Add expert output to out with masking
+ out += h * weights[:, i].view(-1, 1)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out
+
+
+class DbrxLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.blocks.{layer_id}"
+
+ self.attn = DbrxNormAttentionNorm(
+ prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
+ )
+
+ moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+ self.moe = moe_cls(f"{prefix}.ffn", config, weights)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ # Self Attention
+ attn_output, attn_res = self.attn(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ moe_output = self.moe(attn_output)
+
+ return moe_output, attn_res
+
+
+class DbrxModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.wte", weights=weights
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ DbrxLayer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.n_layers)
+ ]
+ )
+ self.norm = FastLayerNorm.load_no_bias(
+ prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
+ )
+
+ self.head_size = self.layers[0].attn.self_attn.head_size
+ self.num_heads = self.layers[0].attn.self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashDbrxForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+
+ self.model = DbrxModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
new file mode 100644
index 000000000..9d61c6941
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -0,0 +1,633 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+ FastLinear,
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ attention,
+ paged_attention,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
+from text_generation_server.utils.weights import Weights
+
+
+class DeepseekV2Config(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=11008,
+ moe_intermediate_size=1407,
+ num_hidden_layers=30,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ n_shared_experts=2,
+ n_routed_experts=160,
+ ep_size=1,
+ routed_scaling_factor=1.0,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=64,
+ v_head_dim=128,
+ qk_nope_head_dim=128,
+ topk_method="greedy",
+ n_group=8,
+ topk_group=3,
+ num_experts_per_tok=6,
+ moe_layer_freq=1,
+ first_k_dense_replace=0,
+ norm_topk_prob=False,
+ scoring_func="softmax",
+ aux_loss_alpha=0.001,
+ seq_aux=True,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=100000,
+ eos_token_id=100001,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.ep_size = ep_size
+ self.routed_scaling_factor = routed_scaling_factor
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.topk_method = topk_method
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.num_experts_per_tok = num_experts_per_tok
+ self.moe_layer_freq = moe_layer_freq
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.scoring_func = scoring_func
+ self.aux_loss_alpha = aux_loss_alpha
+ self.seq_aux = seq_aux
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError(
+ "tie_word_embeddings is not supported for Deepseek V2 models."
+ )
+
+ if ep_size != 1:
+ raise ValueError(
+ f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
+ )
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class DeepseekV2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights: Weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.kv_lora_rank = config.kv_lora_rank
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+ self.value_head_size = config.v_head_dim
+ self.head_pad_size = max(self.head_size, self.value_head_size)
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.qk_rope_head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
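+ # Fold the YaRN mscale correction into the softmax scale; it enters squared
+ # because the correction applies to both query and key.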
+ mscale = get_mscale(
+ self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
+ )
+ self.softmax_scale = self.head_size**-0.5 * mscale * mscale
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ if self.q_lora_rank is None:
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+ else:
+ self.q_a_proj = get_linear(
+ weight=weights.get_weights(f"{prefix}.q_a_proj"),
+ bias=(
+ weights.get_tensor(f"{prefix}.q_a_proj.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+ self.q_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.q_a_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.q_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.kv_a_proj_with_mqa = get_linear(
+ weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
+ bias=(
+ weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.kv_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.kv_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.kv_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache: KVCache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ if self.q_lora_rank is None:
+ query = self.q_proj(hidden_states)
+ else:
+ query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
+ query = query.view(-1, self.num_heads, self.head_size)
+
+ _, query_pe = torch.split(
+ query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ compressed_kv, key_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+
+ key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
+ kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
+ -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
+ )
+
+ key_nope, value = torch.split(
+ kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+ )
+
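+ # Rearrange the rope dims from interleaved even/odd pairs into the half-split
+ # layout expected by the rotary embedding kernel.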
+ batch_size, heads, head_dim = query_pe.shape
+ query_pe = (
+ query_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ batch_size, heads, head_dim = key_pe.shape
+ key_pe = (
+ key_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ self.rotary_emb(query_pe, key_pe, cos, sin)
+
+ query[..., self.qk_nope_head_dim :] = query_pe
+ key = torch.empty_like(query)
+ key[..., : self.qk_nope_head_dim] = key_nope
+ key[..., self.qk_nope_head_dim :] = key_pe
+
+ # We need to pad the heads because Flash Attention does not support
+ # qk and v with different head sizes.
+ query = torch.nn.functional.pad(
+ query, (0, self.head_pad_size - self.head_size), value=0
+ )
+ key = torch.nn.functional.pad(
+ key, (0, self.head_pad_size - self.head_size), value=0
+ )
+ value = torch.nn.functional.pad(
+ value, (0, self.head_pad_size - self.value_head_size), value=0
+ )
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # flash attention
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ # Remove padding.
+ attn_output = attn_output[..., : self.value_head_size]
+
+ return self.o_proj(
+ attn_output.reshape(-1, self.num_heads * self.value_head_size)
+ )
+
+
+class DeepseekV2MLP(nn.Module):
+ def __init__(self, prefix: str, config, weights, intermediate_size: int):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ if self.hidden_act != "silu":
+ # Bail out because MoE only supports silu.
+ raise NotImplementedError(
+ "Currently only `silu` is supported as an activation for Deepseek V2."
+ )
+ self.act = ACT2FN[self.hidden_act]
+
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.intermediate_size = intermediate_size // weights.process_group.size()
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
+ )
+
+
+class DeepseekV2MoE(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config: DeepseekV2Config,
+ moe_layer_cls: Type[MoELayer],
+ weights,
+ ):
+ super().__init__()
+
+ self.hidden_dim = config.hidden_size
+ self.moe_intermediate_size = (
+ config.moe_intermediate_size // weights.process_group.size()
+ )
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ # Gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe_layer = moe_layer_cls(
+ prefix=f"{prefix}.experts",
+ n_experts=config.n_routed_experts,
+ n_expert_group=config.n_group,
+ renormalize=config.norm_topk_prob,
+ topk=config.num_experts_per_tok,
+ topk_group=config.topk_group,
+ weights=weights,
+ )
+ assert isinstance(self.moe_layer, MoELayer)
+
+ if config.n_shared_experts is not None:
+ self.shared_experts = DeepseekV2MLP(
+ prefix=f"{prefix}.shared_experts",
+ config=config,
+ weights=weights,
+ intermediate_size=config.moe_intermediate_size
+ * config.n_shared_experts,
+ )
+ else:
+ self.shared_experts = None
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
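+ # The shared experts skip their own all-reduce (reduce=False); their partial
+ # output is added to the routed experts' output and reduced once below.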
+ if self.shared_experts is not None:
+ shared_output = self.shared_experts(x, reduce=False)
+ else:
+ shared_output = None
+
+ router_logits = self.gate(x)
+
+ out = self.moe_layer(x, gating_output=router_logits)
+
+ if shared_output is not None:
+ out = out + shared_output
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class DeepseekV2Layer(nn.Module):
+ def __init__(self, prefix, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+
+ self.self_attn = DeepseekV2Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ )
+
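+ # The first `first_k_dense_replace` layers use a dense MLP; beyond that, every
+ # `moe_layer_freq`-th layer becomes a MoE layer.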
+ if (
+ config.n_routed_experts is not None
+ and layer_id >= config.first_k_dense_replace
+ and layer_id % config.moe_layer_freq == 0
+ ):
+ moe_layer_cls = (
+ SparseMoELayer
+ if SparseMoELayer.is_supported(weights)
+ else DenseMoELayer
+ )
+ self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+ else:
+ self.mlp = DeepseekV2MLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ intermediate_size=config.intermediate_size,
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, residual = self.post_attention_layernorm(
+ attn_output, residual
+ )
+
+ output = self.mlp(normed_attn_res_output)
+
+ return output, residual
+
+
+class DeepseekV2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ DeepseekV2Layer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashDeepseekV2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.model = DeepseekV2Model(
+ "model" if not prefix else f"{prefix}.model", config, weights
+ )
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
new file mode 100644
index 000000000..1a7ce5cf5
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
@@ -0,0 +1,642 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+ FastLinear,
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ attention,
+ paged_attention,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
+from text_generation_server.utils.weights import Weights
+
+
+class DeepseekV3Config(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=11008,
+ moe_intermediate_size=1407,
+ num_hidden_layers=30,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ n_shared_experts=2,
+ n_routed_experts=160,
+ ep_size=1,
+ routed_scaling_factor=1.0,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=64,
+ v_head_dim=128,
+ qk_nope_head_dim=128,
+ topk_method="greedy",
+ n_group=8,
+ topk_group=3,
+ num_experts_per_tok=6,
+ moe_layer_freq=1,
+ first_k_dense_replace=0,
+ norm_topk_prob=False,
+ scoring_func="softmax",
+ aux_loss_alpha=0.001,
+ seq_aux=True,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=100000,
+ eos_token_id=100001,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.ep_size = ep_size
+ self.routed_scaling_factor = routed_scaling_factor
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.topk_method = topk_method
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.num_experts_per_tok = num_experts_per_tok
+ self.moe_layer_freq = moe_layer_freq
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.scoring_func = scoring_func
+ self.aux_loss_alpha = aux_loss_alpha
+ self.seq_aux = seq_aux
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+ if tie_word_embeddings:
+ raise ValueError(
+ "tie_word_embeddings is not supported for Deepseek V3 models."
+ )
+
+ if ep_size != 1:
+ raise ValueError(
+ f"Currently only ep_size == 1 is supported for Deepseek V3 models, was {ep_size}"
+ )
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class DeepseekV3Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights: Weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.kv_lora_rank = config.kv_lora_rank
+ self.q_lora_rank = config.q_lora_rank
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+ self.value_head_size = config.v_head_dim
+ self.head_pad_size = max(self.head_size, self.value_head_size)
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.qk_rope_head_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ mscale = get_mscale(
+ self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
+ )
+ self.softmax_scale = self.head_size**-0.5 * mscale * mscale
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ if self.q_lora_rank is None:
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+ else:
+ self.q_a_proj = get_linear(
+ weight=weights.get_weights(f"{prefix}.q_a_proj"),
+ bias=(
+ weights.get_tensor(f"{prefix}.q_a_proj.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+ self.q_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.q_a_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.q_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.kv_a_proj_with_mqa = get_linear(
+ weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
+ bias=(
+ weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
+ if config.attention_bias
+ else None
+ ),
+ )
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.kv_a_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.kv_b_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.kv_b_proj",
+ weights=weights,
+ bias=config.attention_bias,
+ )
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache: KVCache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ if self.q_lora_rank is None:
+ query = self.q_proj(hidden_states)
+ else:
+ query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
+ query = query.view(-1, self.num_heads, self.head_size)
+
+ _, query_pe = torch.split(
+ query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+ compressed_kv, key_pe = torch.split(
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+
+ key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
+ kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
+ -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
+ )
+
+ key_nope, value = torch.split(
+ kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+ )
+
+ batch_size, heads, head_dim = query_pe.shape
+ query_pe = (
+ query_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ batch_size, heads, head_dim = key_pe.shape
+ key_pe = (
+ key_pe.view(batch_size, heads, head_dim // 2, 2)
+ .transpose(2, 3)
+ .reshape(batch_size, heads, head_dim)
+ )
+ self.rotary_emb(query_pe, key_pe, cos, sin)
+
+ query[..., self.qk_nope_head_dim :] = query_pe
+ key = torch.empty_like(query)
+ key[..., : self.qk_nope_head_dim] = key_nope
+ key[..., self.qk_nope_head_dim :] = key_pe
+
+ # We need to pad the heads because Flash Attention does not support
+ # qk and v with different head sizes.
+ query = torch.nn.functional.pad(
+ query, (0, self.head_pad_size - self.head_size), value=0
+ )
+ key = torch.nn.functional.pad(
+ key, (0, self.head_pad_size - self.head_size), value=0
+ )
+ value = torch.nn.functional.pad(
+ value, (0, self.head_pad_size - self.value_head_size), value=0
+ )
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # flash attention
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ # Remove padding.
+ attn_output = attn_output[..., : self.value_head_size]
+
+ return self.o_proj(
+ attn_output.reshape(-1, self.num_heads * self.value_head_size)
+ )
+
+
+class DeepseekV3MLP(nn.Module):
+ def __init__(self, prefix: str, config, weights, intermediate_size: int):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ if self.hidden_act != "silu":
+ # Bail out because MoE only supports silu.
+ raise NotImplementedError(
+ "Currently only `silu` is supported as an activation for Deepseek V3."
+ )
+ self.act = ACT2FN[self.hidden_act]
+
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.intermediate_size = intermediate_size // weights.process_group.size()
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
+ )
+
+
+class DeepseekV3MoE(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config: DeepseekV3Config,
+ moe_layer_cls: Type[MoELayer],
+ weights,
+ ):
+ super().__init__()
+
+ self.hidden_dim = config.hidden_size
+ self.moe_intermediate_size = (
+ config.moe_intermediate_size // weights.process_group.size()
+ )
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ # Gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
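+ # DeepSeek-V3's aux-loss-free ("noaux_tc") routing uses a per-expert score
+ # correction bias: allocate it here (zero-initialised) and pass it on to the
+ # MoE layer's top-k selection.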
+ if config.topk_method == "noaux_tc":
+ self.gate.e_score_correction_bias = torch.zeros(
+ config.n_routed_experts, device=weights.device
+ )
+ else:
+ self.gate.e_score_correction_bias = None
+
+ self.moe_layer = moe_layer_cls(
+ prefix=f"{prefix}.experts",
+ n_experts=config.n_routed_experts,
+ n_expert_group=config.n_group,
+ renormalize=config.norm_topk_prob,
+ topk=config.num_experts_per_tok,
+ topk_group=config.topk_group,
+ weights=weights,
+ scoring_func=config.scoring_func,
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ )
+ assert isinstance(self.moe_layer, MoELayer)
+
+ if config.n_shared_experts is not None:
+ self.shared_experts = DeepseekV3MLP(
+ prefix=f"{prefix}.shared_experts",
+ config=config,
+ weights=weights,
+ intermediate_size=config.moe_intermediate_size
+ * config.n_shared_experts,
+ )
+ else:
+ self.shared_experts = None
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.shared_experts is not None:
+ shared_output = self.shared_experts(x, reduce=False)
+ else:
+ shared_output = None
+
+ router_logits = self.gate(x)
+
+ out = self.moe_layer(x, gating_output=router_logits)
+
+ if shared_output is not None:
+ out = out + shared_output
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class DeepseekV3Layer(nn.Module):
+ def __init__(self, prefix, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+
+ self.self_attn = DeepseekV3Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ )
+
+ if (
+ config.n_routed_experts is not None
+ and layer_id >= config.first_k_dense_replace
+ and layer_id % config.moe_layer_freq == 0
+ ):
+ moe_layer_cls = (
+ SparseMoELayer
+ if SparseMoELayer.is_supported(weights)
+ else DenseMoELayer
+ )
+ self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+ else:
+ self.mlp = DeepseekV3MLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ intermediate_size=config.intermediate_size,
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cu_seqlen_prefill: torch.Tensor,
+ kv_cache,
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, residual = self.post_attention_layernorm(
+ attn_output, residual
+ )
+
+ output = self.mlp(normed_attn_res_output)
+
+ return output, residual
+
+
+class DeepseekV3Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ DeepseekV3Layer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashDeepseekV3ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights: Weights):
+ super().__init__()
+
+ self.model = DeepseekV3Model(
+ "model" if not prefix else f"{prefix}.model", config, weights
+ )
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
new file mode 100644
index 000000000..79f21b0f3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -0,0 +1,555 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class Gemma2Config(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=256128,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ head_dim=256,
+ hidden_act="gelu_pytorch_tanh",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class Gemma2FastRMSNorm(FastRMSNorm):
+ @classmethod
+ def load(cls, prefix: str, weights, eps=1e-6):
+ dtype = weights.dtype
+ weights.dtype = torch.float32
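+ # Gemma2 checkpoints store (weight - 1); add 1 back so the generic RMSNorm
+ # multiply-by-weight path produces the expected (1 + weight) scaling.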
+ weight = weights.get_tensor(f"{prefix}.weight") + 1
+ weights.dtype = dtype
+ new = cls(weight, eps)
+ new.dtype = dtype
+ return new
+
+ # perform the multiplication in full precision and downcast after
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = hidden_states * self.weight
+ return hidden_states.to(self.dtype), residual
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.head_dim
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+class FlashGemma2Attention(torch.nn.Module):
+ def __init__(
+ self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.head_dim
+ self.causal = causal
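+ # Gemma2 alternates local sliding-window and global attention layers;
+ # window_size = -1 disables the window for the global layers.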
+ if is_sliding:
+ self.window_size = config.sliding_window
+ else:
+ self.window_size = -1
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ # self.softmax_scale = self.head_size**-0.5
+ self.softmax_scale = config.query_pre_attn_scalar**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+ self.softcap = config.attn_logit_softcapping
+
+ query_key_value = load_attention(config, prefix, weights)
+ self.query_key_value = TensorParallelMultiAdapterLinear.load(
+ query_key_value,
+ layer_id,
+ ["q_proj", "k_proj", "v_proj"],
+ sizes=[
+ self.head_size * config.num_attention_heads,
+ self.head_size * config.num_key_value_heads,
+ self.head_size * config.num_key_value_heads,
+ ],
+ process_group=weights.process_group,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ layer_id,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.window_size,
+ softcap=self.softcap,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ softcap=self.softcap,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Gemma2MLP(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ act = config.hidden_activation
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ layer_id,
+ ["gate_proj", "up_proj"],
+ sizes=[
+ config.intermediate_size,
+ config.intermediate_size,
+ ],
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ layer_id,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class FlashGemma2Layer(nn.Module):
+ def __init__(
+ self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+ ):
+ super().__init__()
+ self.self_attn = FlashGemma2Attention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ causal=causal,
+ is_sliding=is_sliding,
+ )
+ self.mlp = Gemma2MLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+ )
+
+ self.input_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.pre_feedforward_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.post_feedforward_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
+ normed_attn_res_output = normed_attn_res_output + res
+ res = normed_attn_res_output
+
+ pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
+ mlp_output = self.mlp(pre_normed, adapter_data)
+ post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
+
+ return post_hidden_states, normed_attn_res_output
+
+
+class FlashGemma2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.layers = nn.ModuleList(
+ [
+ FlashGemma2Layer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ causal=causal,
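+ # Gemma 2 alternates attention types: even layers use sliding-window
+ # attention, odd layers attend globally.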
+ is_sliding=layer_id % 2 == 0,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = Gemma2FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data: Optional[torch.Tensor],
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGemma2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, *, causal: bool = True):
+ super().__init__()
+
+ embed_norm = config.hidden_size**0.5
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
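+ # Gemma scales token embeddings by sqrt(hidden_size); bake the factor into the weights.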
+ self.embed_tokens.weight *= embed_norm
+
+ self.model = FlashGemma2Model(
+ prefix=prefix, config=config, weights=weights, causal=causal
+ )
+ self.lm_head = SpeculativeHead.load(
+ prefix=(
+ f"{prefix}.embed_tokens"
+ if config.tie_word_embeddings
+ else f"{prefix}.lm_head"
+ ),
+ config=config,
+ weights=weights,
+ )
+ self.softcap = config.final_logit_softcapping
+ assert isinstance(self.softcap, float)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ input_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+
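+ # Final logit soft-capping: scale down, squash with tanh, then scale back up,
+ # bounding the logits to (-softcap, softcap).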
+ logits /= self.softcap
+ logits = torch.tanh(logits)
+ logits *= self.softcap
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
new file mode 100644
index 000000000..609f03acc
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -0,0 +1,469 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class GemmaConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=256128,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ head_dim=256,
+ hidden_act="gelu_pytorch_tanh",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class GemmaFastRMSNorm(FastRMSNorm):
+ @classmethod
+ def load(cls, prefix: str, weights, eps=1e-6):
+ dtype = weights.dtype
+ weights.dtype = torch.float32
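+ # Gemma checkpoints store the RMSNorm weight as (gamma - 1), so add 1 back at load time.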
+ weight = weights.get_tensor(f"{prefix}.weight") + 1
+ weights.dtype = dtype
+ new = cls(weight, eps)
+ new.dtype = dtype
+ return new
+
+ # perform the multiplication in full precision and downcast after
+ def forward(self, hidden_states, residual=None):
+ if residual is not None:
+ hidden_states += residual
+ residual = hidden_states
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = hidden_states * self.weight
+ return hidden_states.to(self.dtype), residual
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.head_dim
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+class FlashGemmaAttention(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.head_size = config.head_dim
+ self.causal = causal
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ causal=self.causal,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GemmaMLP(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class FlashGemmaLayer(nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+ self.self_attn = FlashGemmaAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
+ )
+ self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = GemmaFastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = GemmaFastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output)
+
+ return mlp_output, attn_res
+
+
+class FlashGemmaModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, causal: bool):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.layers = nn.ModuleList(
+ [
+ FlashGemmaLayer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ causal=causal,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = GemmaFastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data: Optional[torch.Tensor],
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGemmaForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, *, causal: bool = True):
+ super().__init__()
+
+ embed_norm = config.hidden_size**0.5
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.embed_tokens.weight *= embed_norm
+
+ self.model = FlashGemmaModel(
+ prefix=prefix, config=config, weights=weights, causal=causal
+ )
+ self.lm_head = SpeculativeHead.load(
+ prefix=(
+ f"{prefix}.embed_tokens"
+ if config.tie_word_embeddings
+ else f"{prefix}.lm_head"
+ ),
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ input_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
new file mode 100644
index 000000000..10024a6de
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -0,0 +1,451 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+
+
+def load_qkv(config, prefix: str, weights, head_size, num_heads):
+ if config.quantize == "gptq":
+ return _load_qkv_gptq(
+ config,
+ prefix,
+ weights,
+ )
+ else:
+ return _load_qkv(config, prefix, weights, head_size, num_heads)
+
+
+def _load_qkv_gptq(config, prefix: str, weights):
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ # Weights
+ weight = weights.get_weights_col_packed_qkv(
+ f"{prefix}.c_attn",
+ config.num_attention_heads,
+ config.num_attention_heads,
+ )
+
+ # Bias
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ total_size = shape[0]
+ assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
+ single_size = total_size // 3
+ assert single_size % world_size == 0
+ block_size = single_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensors = []
+ for i in range(3):
+ tensor = slice_[start + i * single_size : stop + i * single_size]
+ tensors.append(tensor)
+ bias = torch.cat(tensors, dim=0)
+ bias = bias.to(device=weights.device)
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def _load_qkv(config, prefix: str, weights, head_size, num_heads):
+ """Load QKV from a single, transposed matrix."""
+
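+ # GPT-2 uses Conv1D modules, so c_attn stores its weight transposed relative to
+ # nn.Linear and packs q, k and v along the last dimension; take this rank's
+ # shard from each third and transpose at the end.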
+ slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
+ shape = slice_.get_shape()
+ total_size = shape[1]
+ assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
+ world_size = weights.process_group.size()
+ single_size = total_size // 3
+ assert single_size % world_size == 0
+ rank = weights.process_group.rank()
+
+ # Weights
+ block_size = single_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ tensors = []
+ for i in range(3):
+ tensor = slice_[:, start + i * single_size : stop + i * single_size]
+ tensors.append(tensor)
+ weight = torch.cat(tensors, dim=1).T
+ weight = weight.to(dtype=weights.dtype)
+ weight = weight.to(device=weights.device)
+
+ # Bias
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ total_size = shape[0]
+ single_size = total_size // 3
+ block_size = single_size // world_size
+ assert single_size % world_size == 0
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ b = []
+ for i in range(3):
+ tensor = slice_[start + i * single_size : stop + i * single_size]
+ b.append(tensor)
+ bias = torch.cat(b, dim=0)
+ bias = bias.to(dtype=weights.dtype)
+ bias = bias.to(device=weights.device)
+ assert list(bias.shape) == [
+ 3 * num_heads * head_size
+ ], f"{weight.shape} != {[3 * num_heads * head_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ """load_row, but with transposed weight matrices."""
+
+ if config.quantize == "gptq":
+ weight = weights.get_weights_row(prefix)
+ else:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+
+ if bias and weights.process_group.rank() == 0:
+ # Only load the bias on rank 0; the row-parallel all-reduce sums outputs
+ # across shards, so the bias must be added exactly once.
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ return TensorParallelRowLinear(
+ get_linear(weight, bias), process_group=weights.process_group
+ )
+
+
+def load_col(config, prefix: str, weights, bias: bool):
+ """load_col, but with transposed weight matrices."""
+ if config.quantize == "gptq":
+ weight = weights.get_multi_weights_col([prefix], dim=1)
+ else:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+
+ if bias:
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ else:
+ bias = None
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+class FlashGPT2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+
+ self.head_size = self.hidden_size // self.num_heads
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.query_key_value = load_qkv(
+ config,
+ prefix=prefix,
+ weights=weights,
+ head_size=self.head_size,
+ num_heads=self.num_heads,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = load_row(
+ config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=True,
+ )
+
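+ # GPT-2 has no grouped-query attention, so the query-to-KV head mapping is the identity.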
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ query, key, value = self.query_key_value(hidden_states).split(
+ self.head_size * self.num_heads, dim=1
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_heads, self.head_size)
+ value = value.view(-1, self.num_heads, self.head_size)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GPT2MLP(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ act = config.activation_function
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.c_fc = load_col(
+ config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
+ )
+ self.c_proj = load_row(
+ config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=True,
+ )
+
+ intermediate_size = (
+ config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+ )
+
+ self.intermediate_size = intermediate_size // weights.process_group.size()
+
+ def forward(self, hidden_states):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ return self.c_proj(hidden_states)
+
+
+class FlashGPT2Layer(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.self_attn = FlashGPT2Attention(
+ prefix=f"{prefix}.attn", config=config, weights=weights
+ )
+ self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
+ )
+ self.post_attention_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.ln_2",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = attn_output + residual
+ residual = hidden_states
+
+ hidden_states = self.post_attention_layernorm(hidden_states)
+
+ mlp_output = self.mlp(hidden_states)
+
+ return residual + mlp_output, residual
+
+
+class FlashGPT2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.layers = nn.ModuleList(
+ [
+ FlashGPT2Layer(
+ prefix=(
+ f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
+ ),
+ config=config,
+ weights=weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ self.norm = nn.LayerNorm.load(
+ prefix="ln_f" if not prefix else f"{prefix}.ln_f",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return hidden_states
+
+
+class FlashGPT2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=("wte" if not prefix else f"{prefix}.wte"),
+ weights=weights,
+ )
+ self.embed_positions = TensorParallelEmbedding(
+ prefix=("wpe" if not prefix else f"{prefix}.wpe"),
+ weights=weights,
+ )
+
+ self.model = FlashGPT2Model(prefix, config, weights)
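+ # GPT-2 ties the LM head to the token embedding, so the head is loaded from wte.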
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="wte" if not prefix else f"{prefix}.wte",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ token_embeds = self.embed_tokens(input_ids)
+ position_embeds = self.embed_positions(position_ids)
+ inputs_embeds = token_embeds + position_embeds
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
new file mode 100644
index 000000000..41eeab78c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -0,0 +1,389 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from habana_frameworks.torch.hpex.kernels import (
+ RotaryPosEmbeddingMode,
+ apply_rotary_pos_emb,
+)
+
+
+def load_attention(config, prefix: str, weights):
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+ # Only load the bias on rank 0; the row-parallel all-reduce sums outputs
+ # across shards, so the bias must be added exactly once.
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ linear = get_linear(weight, bias)
+ return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+class GPTJRotary(PositionRotaryEmbedding):
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ):
+ num_tokens = query.shape[0]
+ head_size = query.shape[-1]
+ rope_mode = RotaryPosEmbeddingMode.PAIRWISE
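+ # GPT-J rotates adjacent channel pairs, so cos/sin are expanded per pair
+ # to match the pairwise layout expected by the HPU kernel.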
+ sin = torch.repeat_interleave(sin, 2, dim=-1)
+ cos = torch.repeat_interleave(cos, 2, dim=-1)
+ rotary_dim = cos.shape[-1]
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, head_size)
+ query_rot = query[..., :rotary_dim]
+ query_pass = query[..., rotary_dim:]
+ query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+ query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, head_size)
+ key_rot = key[..., :rotary_dim]
+ key_pass = key[..., rotary_dim:]
+ key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+ key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
+
+
+class FlashGPTJAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+
+ self.head_size = self.hidden_size // self.num_heads
+ self.softmax_scale = self.head_size**-0.5
+ self.rotary_dim = config.rotary_dim
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.query_key_value = load_attention(
+ config,
+ prefix=prefix,
+ weights=weights,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = load_row(
+ config,
+ prefix=f"{prefix}.out_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ self.rotary_emb = GPTJRotary.static(
+ config=config,
+ dim=self.rotary_dim,
+ base=10000,
+ device=weights.device,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ query, key, value = self.query_key_value(hidden_states).split(
+ self.head_size * self.num_heads, dim=1
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_heads, self.head_size)
+ value = value.view(-1, self.num_heads, self.head_size)
+
+ # Compute rotary embeddings on rotary_ndims
+ if self.rotary_dim is not None:
+ self.rotary_emb(
+ query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin
+ )
+ else:
+ self.rotary_emb(query, key, cos, sin)
+
+ kv_cache.store(
+ key=key,
+ value=value,
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key,
+ value=value,
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GPTJMLP(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ act = config.activation_function
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.fc_in = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.fc_in", weights=weights, bias=True
+ )
+
+ self.fc_out = load_row(
+ config,
+ prefix=f"{prefix}.fc_out",
+ weights=weights,
+ bias=True,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.fc_in(hidden_states)
+ hidden_states = self.act(hidden_states)
+ return self.fc_out(hidden_states)
+
+
+class FlashGPTJLayer(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.self_attn = FlashGPTJAttention(
+ prefix=f"{prefix}.attn", config=config, weights=weights
+ )
+ self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+ # Self Attention
+ attn_output = self.self_attn(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
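+ # GPT-J uses a parallel residual: attention and MLP both consume the same
+ # normalized input and their outputs are summed together.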
+ feed_forward_hidden_states = self.mlp(hidden_states)
+
+ return attn_output + feed_forward_hidden_states, residual
+
+
+class FlashGPTJModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.config = config
+
+ self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights)
+ self.layers = nn.ModuleList(
+ [
+ FlashGPTJLayer(
+ prefix=(
+ f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
+ ),
+ config=config,
+ weights=weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ self.ln_f = FastLayerNorm.load(
+ prefix="ln_f" if not prefix else f"{prefix}.ln_f",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.wte(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.ln_f(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGPTJForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+ self.model = FlashGPTJModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
new file mode 100644
index 000000000..81af55603
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -0,0 +1,658 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+
+from text_generation_server.layers.attention import (
+ KVCache,
+ get_kv_scales,
+)
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+ FastLayerNorm,
+)
+from text_generation_server.layers import (
+ FastLinear,
+)
+from text_generation_server.utils.weights import (
+ Weights,
+)
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+
+def load_attention(config, prefix: str, weights, layer_id):
+ # Only defined in granite.
+ bias = getattr(config, "attention_bias", False)
+ head_size = config.hidden_size // config.num_attention_heads
+ sizes = None
+ prefixes = None
+
+ if config.model_type == "phi3":
+ base_layer = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.qkv_proj",
+ weights=weights,
+ bias=bias,
+ num_heads=config.num_attention_heads,
+ num_key_value_heads=config.num_key_value_heads,
+ )
+ prefixes = ["qkv_proj"]
+ elif config.model_type == "baichuan":
+ prefix = f"{prefix}.W_pack"
+ base_layer = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=prefix,
+ weights=weights,
+ bias=bias,
+ num_heads=config.num_attention_heads,
+ num_key_value_heads=config.num_key_value_heads,
+ )
+ prefixes = [prefix]
+ else:
+ prefixes = ["q_proj", "k_proj", "v_proj"]
+ sizes = [
+ head_size * config.num_attention_heads,
+ head_size * config.num_key_value_heads,
+ head_size * config.num_key_value_heads,
+ ]
+ base_layer = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=bias,
+ )
+
+ return TensorParallelMultiAdapterLinear.load(
+ base_layer=base_layer,
+ layer_id=layer_id,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+
+
+@contextmanager
+def no_fp8(weights: Weights):
+ """De-activate fp8 auto conversion for the duration of this context manager"""
+ weights_loader = weights.weights_loader
+ if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+ weights_loader = HybridFP8UnquantLoader(
+ weights_loader.activation_scale_ub, to_fp8=False
+ )
+
+ with weights.use_loader(weights_loader):
+ yield
+
+
+class FlashLlamaAttention(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ # Setting defaults for baichuan custom config which doesn't apply them.
+ config.rope_theta = getattr(config, "rope_theta", 10000)
+ config.num_key_value_heads = getattr(
+ config, "num_key_value_heads", config.num_attention_heads
+ )
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ # `config.attention_multiplier` is used in Granite
+ self.softmax_scale = getattr(
+ config, "attention_multiplier", self.head_size**-0.5
+ )
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ if config.num_key_value_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights, index)
+ self.index = index
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=getattr(config, "attention_bias", False),
+ )
+
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ index,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
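+ # Grouped-query attention: map each query head to the KV head it shares.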
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache: KVCache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_scales=self.kv_scales,
+ kv_cache=kv_cache,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Phi3MoE(nn.Module):
+ def __init__(
+ self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+ ):
+ super().__init__()
+
+ # gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe = moe_layer_cls(
+ prefix=f"{prefix}.experts",
+ n_experts=config.num_local_experts,
+ n_expert_group=None,
+ renormalize=True,
+ topk=config.num_experts_per_tok,
+ topk_group=None,
+ weights=weights,
+ gate_proj_name="w1",
+ up_proj_name="w3",
+ down_proj_name="w2",
+ )
+
+ self.process_group = weights.process_group
+
+ def forward(self, x, adapter_data) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(x)
+ out = self.moe(x, gating_output=router_logits)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class LlamaMLP(nn.Module):
+ def __init__(self, prefix, config, weights, index):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ self.act = (
+ ACT2FN[self.hidden_act]
+ if "gelu" not in self.hidden_act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh"
+ if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none"
+ ),
+ )
+ )
+ prefixes = None
+ sizes = None
+
+ # Fuse gate and up proj
+ bias = getattr(config, "mlp_bias", False)
+ if config.model_type == "phi3":
+ gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+ config,
+ prefix=f"{prefix}.gate_up_proj",
+ weights=weights,
+ bias=bias,
+ )
+ else:
+ prefixes = ["gate_proj", "up_proj"]
+ sizes = [
+ config.intermediate_size,
+ config.intermediate_size,
+ ]
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=bias,
+ )
+
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ index,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=bias,
+ )
+
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ index,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ self.hidden_size = config.hidden_size
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class FlashLlamaLayer(nn.Module):
+ def __init__(self, index, prefix, config, weights):
+ super().__init__()
+
+ with no_fp8(weights):
+ self.self_attn = FlashLlamaAttention(
+ index=index,
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ )
+
+ if config.model_type == "phimoe":
+ moe_layer_cls = (
+ SparseMoELayer
+ if SparseMoELayer.is_supported(weights)
+ else DenseMoELayer
+ )
+ self.mlp = Phi3MoE(
+ f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+ )
+ # With MoE, the layernorms are not RMSNorms and they have a bias
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ else:
+ self.mlp = LlamaMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
+ )
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ # Used in Granite
+ # This could eventually be baked into the weights like we do for the embeddings/lm_head
+ # but this would mean modifying the lora code
+ self.residual_multiplier = getattr(config, "residual_multiplier", None)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ cross_attention_states,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if self.residual_multiplier is not None:
+ attn_output *= self.residual_multiplier
+
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+ if self.residual_multiplier is not None:
+ mlp_output *= self.residual_multiplier
+
+ return mlp_output, attn_res
+
+
+class FlashLlamaModel(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+
+ # Skip fp8 quant for first and last layers
+ self.layers = nn.ModuleList()
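+ # Mllama-style checkpoints interleave cross-attention layers at the indices
+ # given by config.cross_attention_layers (empty for plain Llama).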
+ self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
+ with no_fp8(weights):
+ self.layers.append(
+ FlashLlamaLayer(
+ index=0,
+ prefix=f"{prefix}.layers.0",
+ config=config,
+ weights=weights,
+ )
+ )
+
+ # Skip first and last layers
+ for layer_id in range(1, config.num_hidden_layers - 1):
+ if layer_id in self.cross_attention_layers:
+ from text_generation_server.models.custom_modeling.flash_mllama import (
+ FlashLlamaCrossLayer,
+ )
+
+ self.layers.append(
+ FlashLlamaCrossLayer(
+ index=layer_id,
+ prefix=(f"{prefix}.layers.{layer_id}"),
+ config=config,
+ weights=weights,
+ )
+ )
+ else:
+ self.layers.append(
+ FlashLlamaLayer(
+ index=layer_id,
+ prefix=(f"{prefix}.layers.{layer_id}"),
+ config=config,
+ weights=weights,
+ )
+ )
+
+ with no_fp8(weights):
+ last_layer_id = config.num_hidden_layers - 1
+ self.layers.append(
+ FlashLlamaLayer(
+ index=last_layer_id,
+ prefix=(f"{prefix}.layers.{last_layer_id}"),
+ config=config,
+ weights=weights,
+ )
+ )
+
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ cross_attention_states=None,
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ cross_attention_states,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashLlamaForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, name=None):
+ if name is None:
+ name = "model"
+ super().__init__()
+ with no_fp8(weights):
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ f"{name}.embed_tokens"
+ if not prefix
+ else f"{prefix}.{name}.embed_tokens"
+ ),
+ weights=weights,
+ )
+ self.model = FlashLlamaModel(
+ prefix=name if not prefix else f"{prefix}.{name}",
+ config=config,
+ weights=weights,
+ )
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ # Used in Granite
+ embedding_multiplier = getattr(config, "embedding_multiplier", None)
+ if embedding_multiplier is not None:
+ self.embed_tokens.weight.data *= embedding_multiplier
+ prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
+ with no_fp8(weights):
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix,
+ weights,
+ )
+
+ # Used in Granite
+ self.logits_scaling = getattr(config, "logits_scaling", None)
+ if self.logits_scaling is not None and self.lm_head.head is not None:
+ try:
+ # Scale the weights directly
+ self.lm_head.head.linear.weight.data /= self.logits_scaling
+ self.logits_scaled = True
+ except Exception:
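+ # Pre-scaling the weights failed (e.g. they cannot be modified in place);
+ # fall back to dividing the logits at runtime in forward().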
+ self.logits_scaled = False
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ cross_attention_states=None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data=adapter_data,
+ cross_attention_states=cross_attention_states,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+
+ # Used in Granite
+ if self.logits_scaling is not None and not self.logits_scaled:
+ logits /= self.logits_scaling
+ if speculative_logits is not None:
+ speculative_logits /= self.logits_scaling
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
new file mode 100644
index 000000000..88548042d
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
@@ -0,0 +1,285 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Llava-NeXT model."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.image_processing_utils import select_best_resolution
+
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+ load_vision_model,
+)
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (height, width).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (`torch.Tensor`):
+ The image tensor, assumed to be of shape (num_channels, height, width).
+ original_size (`tuple`):
+ The original size of the image (height, width).
+
+ Returns:
+ `torch.Tensor`: The unpadded image tensor.
+ """
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(original_height * scale_factor)
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(original_width * scale_factor)
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
+
+
+# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
+class LlavaNextMultiModalProjector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ self.linear_1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class FlashLlavaNextForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = config.quantize
+ vision_config = config.vision_config
+        # Instead of running the full vision tower and selecting hidden_states[-2],
+        # only instantiate the layers up to the requested feature layer and skip pooling.
+ if config.vision_feature_layer < 0:
+ vision_config.num_hidden_layers += config.vision_feature_layer + 1
+ else:
+ vision_config.num_hidden_layers = config.vision_feature_layer + 1
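+        # e.g. with vision_feature_layer = -2 and a 24-layer vision tower, only 23 layers are
+        # instantiated, so the tower's final hidden state is already the requested feature.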
+ self.vision_tower = load_vision_model(
+ prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+ config=config.vision_config,
+ weights=weights,
+ )
+
+ self.multi_modal_projector = LlavaNextMultiModalProjector(
+ prefix="multi_modal_projector", config=config, weights=weights
+ )
+
+ self.image_newline = weights.get_tensor("image_newline")
+
+ self.vocab_size = config.text_config.vocab_size
+ self.config = config
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ self.text_model = load_text_model(
+ prefix="language_model" if not prefix else f"{prefix}.language_model",
+ config=config.text_config,
+ weights=weights,
+ )
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def _merge_input_ids_with_image_features(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_features: torch.Tensor,
+ ):
+ """In place merges in vision_embeddings with inputs_embeds."""
+ mask = torch.where(input_ids == self.config.image_token_index)
+ # Let's pray we have enabled enough slots !
+ try:
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+ except Exception as e:
+ raise RuntimeError(
+ f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ pixel_values: torch.FloatTensor = None,
+ # Unused for this model
+ pixel_attention_mask=None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+ if pixel_values is not None and len(pixel_values) > 0:
+ # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+ # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+ # 1. Extract the input embeddings
+
+ # 2. Merge text and images
+ num_images, num_patches, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.view(
+ num_images * num_patches, channels, height, width
+ )
+ image_features = self.vision_tower(pixel_values)
+
+ # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+ # Already done within the clip model
+ selected_image_feature = image_features.last_hidden_state
+
+ if self.config.vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif self.config.vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise RuntimeError(
+ f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+ )
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+
+ # split up image_features for each of the individual images
+ # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+ # if we assume each image has 5 image features (base image + 4 patches)
+ split_sizes = [num_patches] * num_images
+ image_features = torch.split(image_features, split_sizes, dim=0)
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ height = width = (
+ self.config.vision_config.image_size
+ // self.config.vision_config.patch_size
+ )
+
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+
+ if height * width != base_image_feature.shape[0]:
+ raise ValueError(
+ "The number of patches is not consistent with the image size."
+ )
+
+ # Dimensions are intentionally swapped to be bug-compatible with
+ # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ image_feature = image_feature.view(
+ num_patch_height, num_patch_width, height, width, -1
+ )
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ image_feature = torch.cat(
+ (
+ image_feature,
+ self.image_newline[:, None, None].expand(
+ *image_feature.shape[:-1], 1
+ ),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat(
+ (base_image_feature, image_feature), dim=0
+ )
+ else:
+ image_feature = image_feature[0]
+ image_feature = torch.cat(
+ (image_feature, self.image_newline[None]), dim=0
+ )
+ new_image_features.append(image_feature)
+ image_features = torch.stack(new_image_features, dim=0)
+
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, image_features
+ )
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
new file mode 100644
index 000000000..d23d4f679
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -0,0 +1,481 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+
+
+class MistralConfig(PretrainedConfig):
+ model_type = "mistral"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ sliding_window=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MistralAttention(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, layer_id):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ if hasattr(config, "head_dim"):
+ self.head_size = config.head_dim
+ else:
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ query_key_value = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+ self.query_key_value = TensorParallelMultiAdapterLinear.load(
+ query_key_value,
+ layer_id,
+ ["q_proj", "k_proj", "v_proj"],
+ sizes=[
+ self.head_size * config.num_attention_heads,
+ self.head_size * config.num_key_value_heads,
+ self.head_size * config.num_key_value_heads,
+ ],
+ process_group=weights.process_group,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ layer_id,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class MistralMLP(nn.Module):
+ def __init__(self, prefix: str, config, weights, layer_id):
+ super().__init__()
+ self.hidden_act = config.hidden_act
+ self.act = (
+ ACT2FN[self.hidden_act]
+ if "gelu" not in self.hidden_act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh"
+ if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ layer_id,
+ ["gate_proj", "up_proj"],
+ sizes=[
+ config.intermediate_size,
+ config.intermediate_size,
+ ],
+ process_group=weights.process_group,
+ )
+
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ layer_id,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ # TODO: This is a hotfix to be removed & properly refactored.
+ self.quantize = config.quantize
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
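+        # Column 0 holds the gate_proj output and column 1 the up_proj output (they were
+        # concatenated along dim 0 at load time), so the line below computes
+        # down(act(gate) * up), i.e. SwiGLU when hidden_act is silu.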
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+class MistralLayer(nn.Module):
+ def __init__(self, prefix: str, config, weights, layer_id):
+ super().__init__()
+ self.self_attn = MistralAttention(
+ prefix=f"{prefix}.self_attn",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ )
+ self.mlp = MistralMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+
+ return mlp_output, attn_res
+
+
+class MistralModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.layers = nn.ModuleList(
+ [
+ MistralLayer(
+ prefix=f"{prefix}.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ layer_id=layer_id,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ adapter_data: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+        # Get rotary cos and sin once for this forward pass
+        # to avoid re-indexing them in every layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class FlashMistralForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights, name=None):
+ if name is None:
+ name = "model"
+ super().__init__()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ f"{name}.embed_tokens"
+ if not prefix
+ else f"{prefix}.{name}.embed_tokens"
+ ),
+ weights=weights,
+ )
+ self.model = MistralModel(
+ prefix=name if not prefix else f"{prefix}.{name}",
+ config=config,
+ weights=weights,
+ )
+ self.lm_head = SpeculativeHead.load(
+ config,
+ # TODO dirty hack for idefics2.
+ prefix=(
+ "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+ ),
+ weights=weights,
+ )
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
new file mode 100644
index 000000000..1ef6be481
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -0,0 +1,515 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+ FastLinear,
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ attention,
+ paged_attention,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class MixtralConfig(PretrainedConfig):
+ model_type = "mixtral"
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ sliding_window=None,
+ num_experts_per_tok=2,
+ num_local_experts=8,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+def promote_scalar(x: torch.Tensor) -> torch.Tensor:
+ return x.view(1) if len(x.size()) == 0 else x
+
+
+def load_attention(config, prefix: str, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=None))
+
+
+def _load_experts(config, prefix: str, mat, weights):
+ if config.quantize is not None:
+ raise NotImplementedError("Mixtral does not support weight quantization yet.")
+
+ assert mat in ["w1", "w2", "w3"]
+
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ assert (
+ config.intermediate_size % world_size == 0
+ ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
+
+ block_size = config.intermediate_size // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
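+    # Each rank keeps a contiguous block_size slice of the intermediate dimension for every expert;
+    # w1/w3 are sliced along their output rows while w2 (the down projection) is sliced along its
+    # input columns and transposed, so all experts share the same (block_size, hidden_size) layout.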
+
+ tensor = torch.empty(
+ (config.num_local_experts * block_size, config.hidden_size),
+ dtype=weights.dtype,
+ device=weights.device,
+ )
+
+ for i in range(config.num_local_experts):
+ slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
+
+ if mat == "w2":
+ expert_slice = slice_[:, start:stop].t().contiguous()
+ else:
+ expert_slice = slice_[start:stop]
+ tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+ dtype=weights.dtype
+ ).to(device=weights.device)
+ return tensor
+
+
+class MixtralAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+@torch.jit.script
+def select_experts(gate_logits: torch.Tensor, top_k: int):
+ # all_probs: (sequence_length, n_experts) and upcast for softmax
+ all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+ # weights, selected_experts: (sequence_length, top-k)
+ weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
+ weights /= weights.sum(dim=-1, keepdim=True)
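+    # e.g. for a token with logits [2.0, 1.0, 0.5, 0.1] and top_k=2 (illustrative values),
+    # experts 0 and 1 are selected and their softmax probabilities are renormalized to sum to 1.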
+ weights = weights.view(-1)
+ selected_experts = selected_experts.view(-1)
+
+ return selected_experts, weights
+
+
+@torch.jit.script
+def round_up(x: torch.Tensor, value: int):
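+    # Round each element up to the nearest multiple of `value`, e.g. 5 -> 8 and 16 -> 16 for value=8.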
+ return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
+
+
+class MixtralMoE(nn.Module):
+ def __init__(
+ self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
+ ):
+ super().__init__()
+
+ # gating
+ self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+ self.moe = moe_layer_cls(
+ n_expert_group=None,
+ n_experts=config.num_local_experts,
+ prefix=f"{prefix}.experts",
+ renormalize=True,
+ topk=config.num_experts_per_tok,
+ topk_group=None,
+ weights=weights,
+ gate_proj_name="w1",
+ up_proj_name="w3",
+ down_proj_name="w2",
+ )
+ assert isinstance(self.moe, MoELayer)
+
+ self.process_group = weights.process_group
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(x)
+ out = self.moe(x, gating_output=router_logits)
+
+ # Reduce sum
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(out, group=self.process_group)
+
+ return out.view(*x.shape)
+
+
+class MixtralLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+
+ self.self_attn = MixtralAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+
+ moe_layer_cls = (
+ SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
+ )
+ self.moe = MixtralMoE(
+ f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+ )
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ moe_output = self.moe(normed_attn_res_output)
+
+ return moe_output, attn_res
+
+
+class MixtralModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=(
+ "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+ ),
+ weights=weights,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ MixtralLayer(
+ "model" if not prefix else f"{prefix}.model",
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin once for this forward pass
+        # to avoid re-indexing them in every layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashMixtralForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.model = MixtralModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+ weights=weights,
+ )
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
new file mode 100644
index 000000000..216642e08
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
@@ -0,0 +1,986 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mllama model."""
+
+from typing import Optional, Tuple, List
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+import torch.nn.functional as F
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ FastLinear,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaForCausalLM,
+)
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+
+def _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask: torch.Tensor,
+ num_patches: int,
+ target_length: int,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ # Expand aspect ratio mask to target_length
+ batch_size, max_num_tiles = aspect_ratio_mask.shape
+ attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
+ attention_mask = attention_mask.repeat(1, 1, target_length, 1)
+
+ # Mask padding patches
+ pad_patches = target_length - num_patches
+ attention_mask[:, :, -pad_patches:] = 0
+
+ # Invert the mask (0 -> 1, 1 -> 0)
+ attention_mask = 1 - attention_mask
+
+ # Reshape to 2D and create 4D attention mask
+ # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
+ attention_mask = attention_mask.reshape(
+ batch_size, max_num_tiles * target_length, 1
+ )
+ attention_mask = (
+ attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
+ )
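+    # After inversion, 1 marks padded tiles/patches, so the outer product above places
+    # torch.finfo(dtype).min only on (query, key) pairs where both positions are padded.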
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+
+# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
+def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ min_dtype: float,
+ cache_position: torch.Tensor,
+ batch_size: int,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+            The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding, i.e. the part of the cache that is not yet filled.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+            The device to place the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full(
+ (sequence_length, target_length),
+ fill_value=min_dtype,
+ dtype=dtype,
+ device=device,
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(
+ target_length, device=device
+ ) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = (
+ causal_mask.clone()
+ ) # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = (
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[
+ :, :, :, :mask_length
+ ].masked_fill(padding_mask, min_dtype)
+
+ return causal_mask
+
+
+def _prepare_cross_attention_mask(
+ cross_attention_mask: torch.Tensor,
+ num_vision_tokens: int,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # reshape so it can be used by attn module
+ batch_size, text_total_length, *_ = cross_attention_mask.shape
+ cross_attention_mask = cross_attention_mask.repeat_interleave(
+ num_vision_tokens, dim=3
+ )
+ cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
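+    # repeat_interleave expands each per-tile entry to its num_vision_tokens patch positions,
+    # then the image and tile dimensions are flattened into a single key axis.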
+ cross_attention_mask = cross_attention_mask.unsqueeze(1)
+
+ # invert the mask
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+    # apply a full-row bias: a 4D tensor of shape [B, H, S1, 1] whose value is 0 wherever an entire row of the
+    # cross attention mask's last dimension is negative infinity, and 1 otherwise
+ negative_inf_value = torch.finfo(dtype).min
+ full_text_row_masked_out_mask = (
+ (cross_attention_mask != negative_inf_value)
+ .any(dim=-1)
+ .type_as(cross_attention_mask)[..., None]
+ )
+ cross_attention_mask *= full_text_row_masked_out_mask
+
+ return cross_attention_mask, full_text_row_masked_out_mask
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
+class MllamaVisionMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MllamaVisionSdpaAttention(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+
+ self.embed_dim = config.hidden_size
+ self.head_dim = config.hidden_size // config.attention_heads
+ self.num_heads = config.attention_heads // weights.process_group.size()
+
+ self.qkv_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ qkv = self.qkv_proj(hidden_state)
+ query, key, value = qkv.split(
+ [
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_heads,
+ self.head_dim * self.num_heads,
+ ],
+ dim=2,
+ )
+
+ batch_size, q_seq_len, _ = query.shape
+ _, kv_seq_len, _ = key.shape
+
+ query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
+ key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
+ value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ attn_output = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+ output = self.o_proj(attn_output)
+ return output
+
+
+class MllamaVisionEncoderLayer(nn.Module):
+ def __init__(self, *, prefix, config, weights, is_gated: bool):
+ super().__init__()
+
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.attention_heads
+ self.is_gated = is_gated
+ self.intermediate_size = config.intermediate_size
+
+ self.self_attn = MllamaVisionSdpaAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = MllamaVisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ self.input_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
+ )
+ self.post_attention_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
+ )
+
+ # there used to be an if else here, no code path
+ if is_gated:
+ self.gate_attn = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
+ )
+ self.gate_ffn = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
+ )
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # Self Attention
+ residual = hidden_state
+ hidden_state = self.input_layernorm(hidden_state)
+ hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
+ gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
+ hidden_state = residual + gate_attn * hidden_state
+
+ # Feed forward
+ residual = hidden_state
+ hidden_state = self.post_attention_layernorm(hidden_state)
+ hidden_state = self.mlp(hidden_state)
+ gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
+ hidden_state = residual + gate_ffn * hidden_state
+ return hidden_state
+
+
+class MllamaVisionEncoder(nn.Module):
+ def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
+ super().__init__()
+ self.config = config
+ self.layers = [
+ MllamaVisionEncoderLayer(
+ prefix=f"{prefix}.layers.{i}",
+ config=config,
+ weights=weights,
+ is_gated=is_gated,
+ )
+ for i in range(num_layers)
+ ]
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ encoder_states = [hidden_states]
+ for encoder_layer in self.layers:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+
+ hidden_states = layer_outputs
+ encoder_states.append(hidden_states)
+
+ return hidden_states, encoder_states
+
+
+class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.hidden_size = config.hidden_size
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+
+ self.embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.embedding", weights=weights
+ )
+ self.gate = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate"), requires_grad=False
+ )
+
+ def forward(
+ self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
+ ) -> torch.Tensor:
+ embeddings = self.embedding(aspect_ratio_ids)
+ embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
+
+ # Always gated.
+ embeddings = embeddings * self.gate.tanh()
+
+ hidden_state = hidden_state + embeddings
+ return hidden_state
+
+
+class MllamaPrecomputedPositionEmbedding(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.max_num_tiles = config.max_num_tiles
+ self.max_aspect_ratio_id = config.max_aspect_ratio_id
+ self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
+ self.hidden_size = config.hidden_size
+ self.scale = config.hidden_size**-0.5
+
+ self.gate = nn.Parameter(
+ weights.get_tensor(f"{prefix}.gate"), requires_grad=False
+ )
+
+ # position embedding
+ embedding = nn.Parameter(
+ weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
+ )
+ self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
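+        # The learned gate interpolates between the standard position embedding, weighted by
+        # (1 - tanh(gate)) here, and the tile-specific embedding weighted by tanh(gate) in forward().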
+ self.tile_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.tile_embedding", weights=weights
+ )
+
+ def forward(
+ self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
+ ) -> torch.Tensor:
+ # position embeddings
+ hidden_state = hidden_state + self.gated_position_embedding.view(
+ 1, 1, self.num_patches, self.hidden_size
+ )
+
+ # precomputed tile position embeddings
+ tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+ batch_size = hidden_state.shape[0]
+ tile_position_embedding = tile_position_embedding.reshape(
+ batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
+ )
+ gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
+ hidden_state = hidden_state + gated_tile_position_embedding
+
+ return hidden_state
+
+
+class MllamaVisionModel(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.max_num_tiles = config.max_num_tiles
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+ self.intermediate_layers_indices = config.intermediate_layers_indices
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+ self.scale = config.hidden_size**-0.5
+ self.dtype = weights.dtype
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.hidden_size,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+
+ self.class_embedding = nn.Parameter(
+ weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
+ )
+
+ self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
+ prefix=f"{prefix}.gated_positional_embedding",
+ config=config,
+ weights=weights,
+ )
+
+ self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+ prefix=f"{prefix}.pre_tile_positional_embedding",
+ config=config,
+ weights=weights,
+ )
+ self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+ prefix=f"{prefix}.post_tile_positional_embedding",
+ config=config,
+ weights=weights,
+ )
+
+ ## layer norms
+ self.layernorm_pre = nn.LayerNorm.load(
+ prefix=f"{prefix}.layernorm_pre",
+ weights=weights,
+ # torch default
+ eps=1e-05,
+ )
+ self.layernorm_post = nn.LayerNorm.load(
+ prefix=f"{prefix}.layernorm_post",
+ weights=weights,
+ # torch default
+ eps=1e-05,
+ )
+
+ ## encoders
+ self.transformer = MllamaVisionEncoder(
+ prefix=f"{prefix}.transformer",
+ config=config,
+ weights=weights,
+ is_gated=False,
+ num_layers=config.num_hidden_layers,
+ )
+ self.global_transformer = MllamaVisionEncoder(
+ prefix=f"{prefix}.global_transformer",
+ config=config,
+ weights=weights,
+ is_gated=True,
+ num_layers=config.num_global_layers,
+ )
+
+ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ aspect_ratio_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ (
+ batch_size,
+ num_concurrent_media,
+ num_tiles,
+ num_channels,
+ height,
+ width,
+ ) = pixel_values.shape
+
+ pixel_values = pixel_values.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_channels, height, width
+ )
+ aspect_ratio_ids = aspect_ratio_ids.reshape(
+ batch_size * num_concurrent_media, -1
+ )
+
+ # patch embedding
+ patch_embeds = self.patch_embedding(pixel_values)
+ hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+ # tile embeddings
+ _, num_patches, dim = hidden_state.shape
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles, -1, dim
+ )
+ hidden_state = self.pre_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids
+ )
+
+ # apply cls token
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media * num_tiles, num_patches, dim
+ )
+ hidden_state = self.apply_class_embedding(hidden_state)
+ num_patches += 1
+
+ # apply position embeddings
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media, num_tiles, num_patches, dim
+ )
+ hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
+
+ # apply encoder
+ hidden_state = self.layernorm_pre(hidden_state)
+
+ # Compute the number of tokens to pad
+ num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
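+        # e.g. 577 patches per tile (a 24x24 grid plus the CLS token) are padded by 7 to 584,
+        # the next multiple of 8.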
+ # Compute padding tuple for pad function
+ padding = (
+ 0,
+ 0,
+ 0,
+ num_padding_patches,
+ ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+ # Pad the tensor
+ hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
+ slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.reshape(
+ batch_size * num_concurrent_media, -1
+ )
+ attention_mask = _prepare_aspect_ratio_attention_mask(
+ aspect_ratio_mask=attention_mask,
+ num_patches=self.num_patches,
+ target_length=hidden_state.shape[2],
+ dtype=self.dtype,
+ )
+
+ hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
+ hidden_state, all_intermediate_hidden_states = self.transformer(
+ hidden_state,
+ attention_mask=attention_mask,
+ )
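+ # Keep only the hidden states from the encoder layers listed in intermediate_layers_indices
+ # and stack them along a new trailing dimension.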
+ intermediate_hidden_states = [
+ hidden_state
+ for idx, hidden_state in enumerate(all_intermediate_hidden_states)
+ if idx in self.intermediate_layers_indices
+ ]
+ intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
+
+ # apply global encoder
+ hidden_state = self.layernorm_post(hidden_state)
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim,
+ )
+ hidden_state = self.post_tile_positional_embedding(
+ hidden_state, aspect_ratio_ids
+ )
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles * (num_patches + num_padding_patches),
+ dim,
+ )
+ hidden_state, _ = self.global_transformer(
+ hidden_state, attention_mask=attention_mask
+ )
+ hidden_state = hidden_state.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ dim,
+ )
+ hidden_state = hidden_state[:, :, :slice_index]
+
+ # adding intermediate layer outputs
+ hidden_state = hidden_state.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, dim
+ )
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size * num_concurrent_media,
+ num_tiles,
+ num_patches + num_padding_patches,
+ -1,
+ )
+ intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
+ intermediate_hidden_states = intermediate_hidden_states.reshape(
+ batch_size, num_concurrent_media, num_tiles, num_patches, -1
+ )
+ hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
+ return hidden_state
+
+
+class MllamaTextCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, *, prefix, config, weights, layer_idx):
+ super().__init__()
+ self.config = config
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads
+ self.dropout = config.dropout
+ self.hidden_size = config.hidden_size
+ self.head_size = config.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.layer_idx = layer_idx
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ self.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.k_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.v_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ self.q_norm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.k_norm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.softmax_scale = self.head_size**-0.5
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ # past_key_value=None,
+ # attention_mask: Optional[torch.Tensor] = None,
+ # cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ # hidden_states = hidden_states.unsqueeze(0)
+ # bsz, q_len, _ = hidden_states.size()
+ (
+ cross_attention_states,
+ cu_seqlen_q,
+ cu_seqlen_k,
+ indices,
+ ) = cross_attention_states
+ bs = cu_seqlen_q.size(0) - 1
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
+ query_states = self.q_norm(query_states)
+
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
+ value_states = value_states.view(
+ bs, -1, self.num_key_value_heads, self.head_size
+ )
+ key_states = self.k_norm(key_states)
+
+ # key_states = key_states.repeat(1, self.num_key_value_groups, 1)
+ # value_states = value_states.repeat(1, self.num_key_value_groups, 1)
+
+ causal = False
+ # logger.info(
+ # f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
+ # )
+ # execute sdpa
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
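+ # Run Habana's fused SDPA kernel; the cross-attention here is non-causal and unmasked,
+ # and scale=None keeps the kernel's default scaling.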
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ attn_output = fsdpa_op(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=causal,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+ attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+ return attn_output
+
+
+# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
+class MllamaTextMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ shape = x.shape
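+ # gate_proj and up_proj are loaded as a single fused column-parallel matmul;
+ # the result is split below before the gated activation.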
+ gate_up_states = self.gate_up_proj(x)
+ gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
+ result = self.down_proj(
+ self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
+ )
+ return result
+
+
+class FlashLlamaCrossLayer(torch.nn.Module):
+ """Cross-attention transformer block with tanh-gated attention and feedforward."""
+
+ def __init__(self, *, prefix, config, weights, index) -> None:
+ layer_idx = index
+ super().__init__()
+ self.cross_attn = MllamaTextCrossAttention(
+ prefix=f"{prefix}.cross_attn",
+ config=config,
+ weights=weights,
+ layer_idx=layer_idx,
+ )
+
+ self.input_layernorm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.cross_attn_attn_gate = torch.nn.Parameter(
+ weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
+ )
+
+ self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.post_attention_layernorm = MllamaTextRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.cross_attn_mlp_gate = torch.nn.Parameter(
+ weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
+ )
+ self.layer_idx = layer_idx
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ cross_attention_states, # [ IB, ...]
+ hpu_attention_meta,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if cross_attention_states is None:
+ return hidden_states, residual
+ if residual is not None:
+ hidden_states += residual
+
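+ # Only the token positions listed in `indices` attend to the image states;
+ # all other positions are passed through unchanged.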
+ indices = cross_attention_states[-1]
+ out_hidden_states = hidden_states[:]
+ if len(indices) > 0:
+ assert max(indices) < hidden_states.shape[0]
+ hidden_states = hidden_states[indices]
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states = self.cross_attn(
+ hidden_states=hidden_states,
+ # attention_mask=cross_attention_mask,
+ cross_attention_states=cross_attention_states,
+ )
+ hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+ out_hidden_states[indices] = hidden_states
+ hidden_states = out_hidden_states
+
+ return hidden_states, None
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
+class MllamaTextRMSNorm(nn.Module):
+ def __init__(self, weight, eps):
+ super().__init__()
+ self.weight = weight
+ self.variance_epsilon = eps
+
+ @classmethod
+ def load(cls, *, prefix, weights, eps):
+ weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.weight"), requires_grad=False
+ )
+ return cls(weight=weight, eps=eps)
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class FlashMllamaForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ config.text_config._attn_implementation = "sdpa"
+ self.hidden_size = config.text_config.hidden_size
+ self.vision_model = MllamaVisionModel(
+ prefix="vision_model", config=config.vision_config, weights=weights
+ )
+ self.multi_modal_projector = FastLinear.load(
+ prefix="multi_modal_projector", config=config, weights=weights, bias=True
+ )
+ self.text_model = FlashLlamaForCausalLM(
+ prefix="language_model", config=config.text_config, weights=weights
+ )
+ self.config = config
+ self.dtype = weights.dtype
+ self.device = weights.device
+
+ def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
+ if aspect_ratio_ids is None:
+ raise ValueError(
+ "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
+ )
+ # logger.info(f"PIxel values {pixel_values.shape}")
+ batch_size = pixel_values.shape[0]
+ vision_states = self.vision_model(
+ pixel_values, aspect_ratio_ids, aspect_ratio_mask
+ )
+ cross_attention_states = self.multi_modal_projector(vision_states).reshape(
+ -1, vision_states.shape[-2], self.hidden_size
+ )
+ _, _, h = cross_attention_states.shape
+ cross_attention_states = cross_attention_states.view(batch_size, -1, h)
+ # logger.info(f"cross {cross_attention_states.shape}")
+ return cross_attention_states
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor],
+ adapter_data: Optional[torch.Tensor] = None,
+ # XXX: Putting these as optional so that the cuda warmup calls can go through.
+ cross_attention_states: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+ if cross_attention_states is not None:
+ seqlen_q = len(image_indices)
+ n_images = cross_attention_states.shape[0]
+ seqlen_k = cross_attention_states.shape[1]
+ device = cross_attention_states.device
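+ # Build the cumulative sequence lengths (cu_seqlen_q / cu_seqlen_k) and the flat token
+ # indices that take part in cross-attention.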
+ if cu_seqlen_prefill is not None:
+ offset = 0
+ cu_q = []
+ indices = []
+ for index in image_indices:
+ cu_q.append(offset)
+ length = seqlen.input_lengths[index].item()
+ assert index < seqlen.cu_seqlen_q.shape[0]
+ input_ids_offset = seqlen.cu_seqlen_q[index]
+ indices.extend(range(input_ids_offset, input_ids_offset + length))
+ offset += length
+ cu_q.append(offset)
+ cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
+
+ assert max(indices) < input_ids.shape[0]
+
+ cu_seqlen_k = (
+ torch.arange(
+ n_images + 1,
+ device=device,
+ dtype=torch.int32,
+ )
+ * seqlen_k
+ )
+ else:
+ cu_seqlen_q = torch.arange(
+ seqlen_q + 1, device=device, dtype=torch.int32
+ )
+ seqlen_k = cross_attention_states.shape[1]
+ n_images = cross_attention_states.shape[0]
+ cu_seqlen_k = (
+ torch.arange(
+ n_images + 1,
+ device=device,
+ dtype=torch.int32,
+ )
+ * seqlen_k
+ )
+ indices = image_indices[:]
+
+ cross_attention_states = (
+ cross_attention_states,
+ cu_seqlen_q,
+ cu_seqlen_k,
+ indices,
+ )
+
+ outputs = self.text_model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ lm_head_indices=lm_head_indices,
+ adapter_data=adapter_data,
+ cross_attention_states=cross_attention_states,
+ )
+
+ return outputs
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
new file mode 100644
index 000000000..33f63333a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -0,0 +1,420 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
+from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class GPTNeoXConfig(TransformersGPTNeoXConfig):
+ attribute_map = {
+ "num_key_value_heads": "num_attention_heads",
+ }
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+ # The bias is only loaded on the first rank so it is added exactly once after the all-reduce
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ linear = get_linear(weight, bias)
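+ # With parallel residual the row-parallel all-reduce is deferred to the layer
+ # (see FlashNeoXLayer), so the bare linear is returned here.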
+ if config.use_parallel_residual:
+ return linear
+ else:
+ return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
+ weight = weights.get_multi_weights_col([prefix], dim=0)
+ if isinstance(weight, UnquantizedWeight):
+ # Only reorder the fused QKV weight for non-quantized checkpoints
+ weight.weight = (
+ weight.weight.view(
+ num_heads,
+ 3,
+ head_size,
+ hidden_size,
+ )
+ .permute(1, 0, 2, 3)
+ .reshape(-1, hidden_size)
+ )
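+ # After the permute/reshape above, the fused weight rows are laid out as
+ # [all Q heads, all K heads, all V heads].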
+
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
+
+ linear = get_linear(weight, bias)
+ if config.use_parallel_residual:
+ return linear
+ else:
+ return TensorParallelColumnLinear(linear)
+
+
+class FlashNeoxAttention(torch.nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ num_heads = config.num_attention_heads
+ hidden_size = config.hidden_size
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_size = hidden_size // num_heads
+
+ self.rotary_dim = int(config.rotary_pct * self.head_size)
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.rotary_dim,
+ base=config.rotary_emb_base,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ self.query_key_value = load_qkv(
+ config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ num_heads=self.num_heads,
+ head_size=self.head_size,
+ hidden_size=self.hidden_size,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.dense = load_row(
+ config, prefix=f"{prefix}.dense", weights=weights, bias=True
+ )
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+
+ # Compute rotary embeddings on rotary_ndims
+ query_rot = qkv[:, 0][..., : self.rotary_dim]
+ query_pass = qkv[:, 0][..., self.rotary_dim :]
+ key_rot = qkv[:, 1][..., : self.rotary_dim]
+ key_pass = qkv[:, 1][..., self.rotary_dim :]
+
+ # Inplace rotary
+ self.rotary_emb(query_rot, key_rot, cos, sin)
+ qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
+ qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
+
+ kv_cache.store(
+ key=qkv[:, 1],
+ value=qkv[:, 2],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=qkv[:, 0],
+ key=qkv[:, 1],
+ value=qkv[:, 2],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ qkv[:, 0],
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class FlashMLP(nn.Module):
+ def __init__(self, config, prefix, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.dense_h_to_4h = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
+ )
+ self.dense_4h_to_h = load_row(
+ config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+class FlashNeoXLayer(nn.Module):
+ def __init__(self, layer_id, config, weights):
+ super().__init__()
+
+ layer_norm_eps = config.layer_norm_eps
+
+ prefix = f"gpt_neox.layers.{layer_id}"
+
+ self.use_parallel_residual = config.use_parallel_residual
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
+ )
+ self.post_attention_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=layer_norm_eps,
+ )
+ self.attention = FlashNeoxAttention(
+ config, prefix=f"{prefix}.attention", weights=weights
+ )
+
+ self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ if self.use_parallel_residual:
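+ # Parallel residual (GPT-NeoX style): attention and MLP both read the same layer
+ # input and their outputs are summed with it.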
+ ln1_hidden_states, _ = self.input_layernorm(hidden_states)
+
+ attn_output = self.attention(
+ ln1_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
+
+ mlp_output = self.mlp(ln2_hidden_states)
+ intermediate = mlp_output + attn_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+ return intermediate + hidden_states, None
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.attention(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+
+ mlp_output = self.mlp(hidden_states)
+
+ return mlp_output, residual
+
+
+class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
+ config_class = GPTNeoXConfig
+ base_model_prefix = "gpt_neox"
+ supports_gradient_checkpointing = False
+ _no_split_modules = None
+
+
+class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+ self.config = config
+
+ self.embed_in = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_in", weights=weights
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ FlashNeoXLayer(layer_id, config, weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.final_layer_norm = FastLayerNorm.load(
+ prefix=f"{prefix}.final_layer_norm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].attention.head_size
+ self.num_heads = self.layers[0].attention.num_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_in(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.final_layer_norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
+ def __init__(self, prefix, config, weights):
+ super().__init__(config)
+
+ if not prefix:
+ prefix = "gpt_neox"
+ else:
+ prefix = f"{prefix}.gpt_neox"
+
+ self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
+
+ self.embed_out = SpeculativeHead.load(
+ config, prefix="embed_out", weights=weights
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.gpt_neox(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.embed_out(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
new file mode 100644
index 000000000..4d31d5ddf
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+from torch import nn
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+ load_vision_model,
+)
+
+
+class PaliGemmaForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = config.quantize
+ self.vision_tower = load_vision_model(
+ prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
+ config=config.vision_config,
+ weights=weights,
+ )
+ self.post_vision_tower_layernorm = nn.LayerNorm.load(
+ prefix="vision_tower.vision_model.post_layernorm",
+ weights=weights,
+ eps=config.vision_config.layer_norm_eps,
+ )
+
+ self.multi_modal_projector = TensorParallelColumnLinear.load(
+ config,
+ prefix="multi_modal_projector.linear",
+ weights=weights,
+ bias=True,
+ )
+
+ self.vocab_size = config.text_config.vocab_size
+ self.config = config
+
+ text_config = config.text_config
+ text_config.speculator = config.speculator
+ text_config.quantize = config.quantize
+ self.text_model = load_text_model(
+ prefix="language_model" if not prefix else f"{prefix}.language_model",
+ config=config.text_config,
+ weights=weights,
+ )
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ pixel_values: torch.FloatTensor = None,
+ # Unused here
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+ # TODO This is odd but apparently pali gemma position ids start at 1.
+ if cu_seqlen_prefill is not None:
+ position_ids += 1
+
+ if pixel_values is not None:
+ pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
+ image_outputs = self.vision_tower(pixel_values)
+ last_hidden_state = self.post_vision_tower_layernorm(
+ image_outputs.last_hidden_state
+ )
+ image_features = self.multi_modal_projector(last_hidden_state)
+
+ # mask selecting the image token positions in the input sequence
+ mask = input_ids == self.config.image_token_index
+
+ # insert image features into input embeddings
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
new file mode 100644
index 000000000..0c7779124
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -0,0 +1,414 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+
+
+class PhiConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=51200,
+ hidden_size=2560,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="gelu_fast", # llama uses silu
+ layer_norm_eps=1e-05, # llama uses RMSNorm
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ resid_pdrop=0.1, # llama doesn't have this
+ partial_rotary_factor=0.5, # important difference between llama and phi
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.rope_theta = rope_theta
+ self.resid_pdrop = resid_pdrop
+ self.partial_rotary_factor = partial_rotary_factor
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+# this is the same as llama except for Phi uses bias=True
+def load_attention(config, prefix, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if config.quantize not in ["gptq", "awq"]:
+ weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
+
+ # this is the same as llama except for Phi uses bias=True
+ return TensorParallelColumnLinear(get_linear(weight, bias=True))
+
+
+class FlashPhiAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.softmax_scale = self.head_size**-0.5
+ self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.rotary_dim,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()})"
+ )
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ # in llama the dense layer is called "o_proj" and has bias=False
+ self.dense = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.dense",
+ weights=weights,
+ bias=True,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ # Compute query, key, value and split
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+
+ # Reshape query and key for rotary embeddings
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ # NOTE: this is the main difference between Llama and Phi
+ # in llama the rotary embeddings are applied to the whole query and key.
+ # Phi uses PARTIAL rotary embeddings, which are applied only to the first `rotary_dim` dimensions
+ #
+ # Apply partial positional embeddings in place
+ self.rotary_emb(
+ query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
+ )
+
+ # Reshape key and value and cache
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_scales=self.kv_scales,
+ kv_cache=kv_cache,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class PhiMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ # llama weights are up_proj and down_proj and bias=False
+ self.up_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.fc1",
+ weights=weights,
+ bias=True,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.fc2",
+ weights=weights,
+ bias=True,
+ )
+
+ def forward(self, hidden_states):
+ # NOTE: Llama has a fused gate/up projection whose output must be viewed to the intermediate size;
+ # Phi does not, so we can avoid the `view` operation
+ return self.down_proj(self.act(self.up_proj(hidden_states)))
+
+
+class FlashPhiLayer(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+ self.self_attn = FlashPhiAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+ self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ hidden_states, res = self.input_layernorm(hidden_states, residual)
+ # Self Attention
+ attn_output = self.self_attn(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states = self.resid_dropout(attn_output).add(
+ self.resid_dropout(self.mlp(hidden_states))
+ )
+
+ return hidden_states, res
+
+
+class FlashPhiModel(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.layers = nn.ModuleList(
+ [
+ FlashPhiLayer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ self.norm = FastLayerNorm.load(
+ prefix="model.final_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashPhiForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.model = FlashPhiModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+
+ return self.lm_head(hidden_states)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
new file mode 100644
index 000000000..bb585cc4b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
@@ -0,0 +1,254 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Phi-MoE model."""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json",
+}
+
+
+class PhiMoEConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the
+ [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32064):
+ Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`PhiMoEModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 6400):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+ The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
+ allows sequences of up to 4096*32 tokens.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+ contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and
+ `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_mscale` must
+ be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of
+ the attention head size and the `original_max_position_embeddings` must be an integer.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If not specified, will default to `262144`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
+ The number of experts to route per token; it can also be interpreted as the `top-k` routing
+ parameter
+ num_local_experts (`int`, *optional*, defaults to 16):
+ Number of experts per Sparse MLP layer.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabling this will also
+ allow the model to output the auxiliary loss. See [here]() for more details
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.0):
+ The aux loss factor for the total loss.
+ router_jitter_noise (`float`, *optional*, defaults to 0.01):
+ Amount of noise to add to the router.
+
+ ```python
+ >>> from transformers import PhiMoEModel, PhiMoEConfig
+
+ >>> # Initializing a Phi-3 style configuration
+ >>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
+
+ >>> # Initializing a model from the configuration
+ >>> model = PhiMoEModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "phimoe"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32064,
+ hidden_size=4096,
+ intermediate_size=6400,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=1e6,
+ rope_scaling=None,
+ sliding_window=None,
+ attention_dropout=0.0,
+ num_experts_per_tok=2,
+ num_local_experts=16,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ router_jitter_noise=0.01,
+ input_jitter_noise=0.0,
+ attention_bias=False,
+ lm_head_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+ self.attention_bias = attention_bias
+ self.lm_head_bias = lm_head_bias
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.router_jitter_noise = router_jitter_noise
+ self.input_jitter_noise = input_jitter_noise
+
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with six fields, `type`, `short_factor`, `long_factor`, "
+ f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
+ rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None)
+ rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None)
+ original_max_position_embeddings = self.rope_scaling.get(
+ "original_max_position_embeddings", None
+ )
+ if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}"
+ )
+ if not (
+ isinstance(rope_scaling_short_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+ )
+ if (
+ not len(rope_scaling_short_factor)
+ == self.hidden_size // self.num_attention_heads // 2
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+ )
+ if not (
+ isinstance(rope_scaling_long_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+ )
+ if (
+ not len(rope_scaling_long_factor)
+ == self.hidden_size // self.num_attention_heads // 2
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+ )
+ if not isinstance(rope_scaling_short_mscale, (int, float)):
+ raise ValueError(
+ f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}"
+ )
+ if not isinstance(rope_scaling_long_mscale, (int, float)):
+ raise ValueError(
+ f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}"
+ )
+ if not isinstance(original_max_position_embeddings, int):
+ raise ValueError(
+ f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
new file mode 100644
index 000000000..af4b404d0
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -0,0 +1,371 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
+ FastRMSNorm,
+)
+
+
+def load_attention(config, prefix, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+
+
+class Qwen2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
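+ # A max_past of -1 disables the sliding-window limit passed to the prefill attention below.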
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
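+ # Rotary embeddings are applied in place to the query and the key half of kv.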
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class Qwen2MLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states):
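+ # Split the fused gate/up projection and apply the gated activation: act(gate) * up.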
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+ return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class Qwen2Layer(nn.Module):
+ def __init__(self, prefix, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.layers.{layer_id}"
+ self.self_attn = Qwen2Attention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, residual = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ hidden_states = attn_output + residual
+
+ # Post-attention RMSNorm
+ hidden_states, residual = self.post_attention_layernorm(hidden_states)
+ mlp_output = self.mlp(hidden_states)
+ hidden_states = mlp_output + residual
+ return hidden_states
+
+
+class Qwen2Model(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ prefix = f"{prefix}.model" if prefix else "model"
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.layers = nn.ModuleList(
+ [
+ Qwen2Layer(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+ position_ids,
+ )
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states)
+
+ return hidden_states
+
+
+class Qwen2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+
+ self.model = Qwen2Model(prefix, config, weights)
+
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
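+ # With tied embeddings, the speculative head reuses the input embedding weights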
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.{suffix}" if prefix else suffix,
+ weights=weights,
+ )
+
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
+ weights=weights,
+ )
+
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = self.model(
+ inputs_embeds,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
new file mode 100644
index 000000000..141e13a63
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -0,0 +1,692 @@
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+from text_generation_server.layers import (
+ SpeculativeHead,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import FastLayerNorm
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.attention import (
+ attention,
+ paged_attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+ # Only rank 0 loads the bias so it is added exactly once across shards
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+
+ linear = get_linear(weight, bias)
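+ # With parallel attention the decoder layer performs a single all-reduce for
+ # attention + MLP, so skip the row-parallel reduction here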
+ if config.parallel_attn:
+ return linear
+ else:
+ return TensorParallelRowLinear(linear, process_group=weights.process_group)
+
+
+class RWConfig(PretrainedConfig):
+ attribute_map = {
+ "num_hidden_layers": "n_layer",
+ "num_attention_heads": "n_head",
+ "num_key_value_heads": "n_head_kv",
+ }
+
+ def __init__(
+ self,
+ model_type="RefinedWeb",
+ vocab_size=250880,
+ hidden_size=64,
+ num_hidden_layers=None,
+ num_attention_heads=None,
+ num_ln_in_parallel_attention=None,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=True,
+ bos_token_id=1,
+ eos_token_id=2,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ num_kv_heads=None,
+ multi_query=False,
+ alibi=False,
+ new_decoder_architecture=None,
+ bias=False,
+ parallel_attn=False,
+ rope_theta=10_000.0,
+ **kwargs,
+ ):
+ if alibi:
+ raise NotImplementedError(
+ "alibi is not supported by this version of the model"
+ )
+
+ self.model_type = model_type
+ self.alibi = False
+ self.rotary = True
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = 2048
+
+ self.vocab_size = vocab_size
+ # Backward compatibility with n_embed kwarg
+ n_embed = kwargs.pop("n_embed", None)
+ self.hidden_size = hidden_size if n_embed is None else n_embed
+ self.n_layer = (
+ num_hidden_layers
+ if num_hidden_layers is not None
+ else kwargs.pop("n_layer", 2)
+ )
+ self.n_head = (
+ num_attention_heads
+ if num_attention_heads is not None
+ else kwargs.pop("n_head", 8)
+ )
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.num_ln_in_parallel_attn = num_ln_in_parallel_attention
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.bias = bias
+ self.parallel_attn = parallel_attn
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+
+ if num_kv_heads is not None:
+ self.n_head_kv = num_kv_heads
+ else:
+ old_n_head_kv = kwargs.pop("n_head_kv", None)
+ if old_n_head_kv is not None:
+ self.n_head_kv = old_n_head_kv
+ else:
+ self.n_head_kv = 1 if multi_query else self.n_head
+
+ if new_decoder_architecture is not None:
+ self.new_decoder_architecture = new_decoder_architecture
+ elif model_type == "RefinedWeb":
+ self.new_decoder_architecture = True
+ else:
+ self.new_decoder_architecture = False
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+class FlashRWAttention(torch.nn.Module):
+ def __init__(
+ self,
+ config,
+ prefix: str,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.n_head
+ self.num_heads_kv = config.n_head_kv
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+ self.rope_theta = config.rope_theta
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=self.rope_theta,
+ device=weights.device,
+ )
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.query_key_value = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ bias=config.bias,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.dense = load_row(
+ config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+ )
+
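+ # Multi-query: all query heads share KV head 0; otherwise one KV head per query head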
+ if self.num_heads_kv == 1:
+ self.kv_head_mapping = torch.zeros(
+ self.num_heads, dtype=torch.int32, device=weights.device
+ )
+ else:
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
+
+ # Split query from key_value
+ query, kv = qkv.split(
+ [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
+ dim=1,
+ )
+
+ # Prepare query and key_value for indexing
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
+
+ # Inplace rotary
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class FlashRWLargeAttention(torch.nn.Module):
+ def __init__(
+ self,
+ config,
+ prefix: str,
+ weights,
+ ):
+ super().__init__()
+
+ hidden_size = config.hidden_size
+ num_heads = config.n_head
+ # num_heads_kv = config.n_head_kv
+ num_groups = config.n_head_kv
+
+ self.hidden_size = hidden_size
+ self.head_size = hidden_size // num_heads
+ self.num_groups = num_groups
+ self.rope_theta = config.rope_theta
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=self.rope_theta,
+ device=weights.device,
+ )
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ # self.num_groups = num_heads // (num_heads_kv * 2)
+ self.num_heads = num_heads // self.num_groups
+ # self.num_heads_kv = num_heads_kv // self.num_groups
+ process_group = weights.process_group
+
+ if process_group.size() > self.num_groups:
+ raise NotImplementedError(
+ "Tensor Parallelism is not implemented for world_size > n groups"
+ )
+ if self.num_groups % process_group.size() != 0:
+ raise NotImplementedError(
+ f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
+ )
+
+ self.num_groups = self.num_groups // process_group.size()
+
+ self.query_key_value = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.query_key_value",
+ weights=weights,
+ bias=config.bias,
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.dense = load_row(
+ config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+ )
+
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_groups, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_heads)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states)
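+ # Falcon "new decoder" layout: each KV group packs its query heads followed by one key and one value head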
+ qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+ # Split on group dimension
+ query, kv = qkv.split(
+ [self.num_heads, 2],
+ dim=2,
+ )
+ # Merge groups and heads
+ query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
+
+ # Inplace rotary
+ self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, :, 0].contiguous(),
+ value=kv[:, :, 1].contiguous(),
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # flash attention
+ attn_output = attention(
+ query=query,
+ key=kv[:, :, 0],
+ value=kv[:, :, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.dense(
+ attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+ )
+
+
+class FlashMLP(nn.Module):
+ def __init__(self, config, prefix: str, weights):
+ super().__init__()
+ self.act = torch.nn.functional.gelu
+
+ self.dense_h_to_4h = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+ )
+ self.dense_4h_to_h = load_row(
+ config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense_h_to_4h(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dense_4h_to_h(hidden_states)
+ return hidden_states
+
+
+class FlashRWLayer(nn.Module):
+ def __init__(
+ self,
+ layer_id,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+
+ parallel_attn = config.parallel_attn
+ self.parallel_attn = parallel_attn
+
+ prefix = f"{prefix}.h.{layer_id}"
+
+ # NOTE: Falcon 180B uses the ln_attn prefix
+ ln_prefix = "input_layernorm"
+ if config.num_hidden_layers == 80:
+ ln_prefix = "ln_attn"
+
+ self.input_layernorm = FastLayerNorm.load(
+ prefix=f"{prefix}.{ln_prefix}",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ self.self_attention = FlashRWAttention(
+ config,
+ prefix=f"{prefix}.self_attention",
+ weights=weights,
+ )
+ self.post_attention_layernorm = (
+ FastLayerNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ if not parallel_attn
+ else None
+ )
+
+ self.mlp = FlashMLP(
+ config,
+ prefix=f"{prefix}.mlp",
+ weights=weights,
+ )
+
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ if self.parallel_attn:
+ ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ attn_output = self.self_attention(
+ ln_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ mlp_output = self.mlp(ln_hidden_states)
+ intermediate = mlp_output + attn_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+ return intermediate, residual
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.self_attention(
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ if self.post_attention_layernorm is not None:
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+
+ mlp_output = self.mlp(hidden_states)
+
+ return mlp_output, residual
+
+
+class FlashRWLayerNorm(nn.Module):
+ def __init__(self, config, prefix: str, weights):
+ super().__init__()
+ # Falcon 2 exposes the number of layer norms in its config;
+ # if it is not provided, we default to 1
+ self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
+
+ # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+ if config.num_hidden_layers == 80:
+ self.num_ln = 2
+
+ if self.num_ln == 1:
+ self.input_ln = FastLayerNorm.load(
+ prefix=f"{prefix}.input_layernorm",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ elif self.num_ln == 2:
+ self.ln_attn = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_attn",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ self.ln_mlp = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_mlp",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+ else:
+ raise ValueError("Number of layer norms can either be 1 or 2.")
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ ):
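+ # One LN: the same normalized tensor feeds both attention and MLP.
+ # Two LNs (Falcon 180B): attention and MLP each get their own normalization.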
+ if self.num_ln == 1:
+ ln_hidden_states, residual = self.input_ln(hidden_states, residual)
+ return ln_hidden_states, ln_hidden_states, residual
+ elif self.num_ln == 2:
+ ln_attn, residual = self.ln_attn(hidden_states, residual)
+ ln_mlp, _ = self.ln_mlp(residual)
+ return ln_attn, ln_mlp, residual
+
+
+class FlashRWLargeLayer(nn.Module):
+ def __init__(self, layer_id, prefix: str, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.h.{layer_id}"
+
+ self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
+
+ self.self_attention = FlashRWLargeAttention(
+ config,
+ prefix=f"{prefix}.self_attention",
+ weights=weights,
+ )
+ assert config.parallel_attn, "This version only supports parallel_attn=True"
+
+ self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
+
+ self.process_group = weights.process_group
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ # Layer norm.
+ ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
+
+ # Self attention.
+ attn_output = self.self_attention(
+ ln_attn,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ # MLP.
+ mlp_output = self.mlp(ln_mlp)
+
+ intermediate = attn_output + mlp_output
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+ return intermediate, residual
+
+
+class FlashRWPreTrainedModel(PreTrainedModel):
+ config_class = RWConfig
+
+
+class FlashRWModel(FlashRWPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+ self.config = config
+
+ self.word_embeddings = TensorParallelEmbedding(
+ prefix=f"{prefix}.word_embeddings", weights=weights
+ )
+
+ if config.new_decoder_architecture:
+ self.h = nn.ModuleList(
+ [
+ FlashRWLargeLayer(layer_id, prefix, config, weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.cache_size = self.h[0].self_attention.num_groups
+ else:
+ self.h = nn.ModuleList(
+ [
+ FlashRWLayer(layer_id, prefix, config, weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.cache_size = self.h[0].self_attention.num_heads_kv
+
+ self.ln_f = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_f",
+ weights=weights,
+ eps=config.layer_norm_epsilon,
+ )
+
+ self.head_size = self.h[0].self_attention.head_size
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.word_embeddings(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.h):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.ln_f(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashRWForCausalLM(FlashRWPreTrainedModel):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__(config)
+
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+
+ self.transformer = FlashRWModel(prefix, config, weights)
+
+ self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.transformer(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
new file mode 100644
index 000000000..b68f47840
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -0,0 +1,500 @@
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ SpeculativeHead,
+ TensorParallelEmbedding,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.gptq import GPTQWeightsLoader
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+)
+
+
+def load_multi_mqa(
+ config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+ if config.quantize == "gptq":
+ return _load_multi_mqa_gptq(
+ config, prefix, weights, bias, head_size, num_heads, hidden_size
+ )
+ else:
+ return _load_multi_mqa(
+ config, prefix, weights, bias, head_size, num_heads, hidden_size
+ )
+
+
+def _load_multi_mqa_gptq(
+ config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+ from text_generation_server.layers.gptq import GPTQWeight
+
+ if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+
+ slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+ shape = slice_.get_shape()
+ block_size = (shape[1] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[1] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size :]
+ qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+ qweight = qweight.to(device=weights.device)
+
+ slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+ shape = slice_.get_shape()
+ block_size = (shape[1] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[1] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size :]
+ scales = torch.cat([q_tensor, kv_tensor], dim=1)
+ scales = scales.to(device=weights.device)
+
+ slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+ shape = slice_.get_shape()
+ block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert 2 * head_size % (32 // 4) == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+ qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+ qzeros = qzeros.to(device=weights.device)
+
+ loader = weights.weights_loader
+ assert isinstance(loader, GPTQWeightsLoader)
+ loader._get_gptq_params(weights)
+ if loader.quant_method == "gptq":
+ g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+ g_idx = g_idx.to(device=weights.device)
+ elif loader.quant_method == "awq":
+ g_idx = None
+ from text_generation_server.layers.awq.conversion_utils import (
+ fast_awq_to_gptq,
+ )
+
+ qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
+ from text_generation_server.layers.gptq import HAS_EXLLAMA
+
+ weight = GPTQWeight(
+ qweight=qweight,
+ qzeros=qzeros,
+ scales=scales,
+ g_idx=g_idx,
+ bits=loader.bits,
+ groupsize=loader.groupsize,
+ use_awq_kernel=loader.quantize == "awq",
+ use_exllama=HAS_EXLLAMA,
+ )
+
+ if bias:
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ block_size = (shape[0] - 2 * head_size) // world_size
+ assert (shape[0] - 2 * head_size) % world_size == 0
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ q_tensor = slice_[start:stop]
+ kv_tensor = slice_[-2 * head_size :]
+ bias = torch.cat([q_tensor, kv_tensor], dim=0)
+ bias = bias.to(device=weights.device)
+
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+ else:
+ raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+ config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
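+ # MQA checkpoint layout: [query heads | shared key | shared value]; only the query block is sharded across ranks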
+ if any("c_attn" in k for k in weights.routing.keys()):
+ slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
+ shape = slice_.get_shape()
+ world_size = weights.process_group.size()
+ rank = weights.process_group.rank()
+ if config.transpose:
+ block_size = (shape[1] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[1] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[:, start:stop]
+ kv_tensor = slice_[:, -2 * head_size :]
+ weight = torch.cat([q_tensor, kv_tensor], dim=1).T
+ else:
+ block_size = (shape[0] - 2 * head_size) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ assert (shape[0] - 2 * head_size) % world_size == 0
+ q_tensor = slice_[start:stop]
+ kv_tensor = slice_[-2 * head_size :]
+ weight = torch.cat([q_tensor, kv_tensor], dim=0)
+ if bias:
+ slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+ shape = slice_.get_shape()
+ block_size = (shape[0] - 2 * head_size) // world_size
+ assert (shape[0] - 2 * head_size) % world_size == 0
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ q_tensor = slice_[start:stop]
+ kv_tensor = slice_[-2 * head_size :]
+ bias = torch.cat([q_tensor, kv_tensor], dim=0)
+ else:
+ if config.transpose:
+ w = [
+ weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
+ weights.get_tensor(f"{prefix}.kv_attn.weight").T,
+ ]
+ weight = torch.cat(w, dim=0)
+ else:
+ w = [
+ weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
+ weights.get_tensor(f"{prefix}.kv_attn.weight"),
+ ]
+ weight = torch.cat(w, dim=1)
+
+ if bias:
+ b = [
+ weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
+ weights.get_tensor(f"{prefix}.kv_attn.bias"),
+ ]
+ bias = torch.cat(b, dim=0)
+ else:
+ bias = None
+
+ weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+ assert list(weight.shape) == [
+ (num_heads + 2) * head_size,
+ hidden_size,
+ ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
+ if bias is not None:
+ bias = bias.to(dtype=weights.dtype).to(device=weights.device)
+ assert list(bias.shape) == [
+ (num_heads + 2) * head_size
+ ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def load_col(config, prefix: str, weights, bias: bool):
+ if config.transpose:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+ else:
+ weight = weights.get_multi_weights_col([prefix], dim=0)
+
+ if bias:
+ bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+ else:
+ bias = None
+ return TensorParallelColumnLinear(get_linear(weight, bias))
+
+
+def load_row(config, prefix: str, weights, bias: bool):
+ if config.transpose:
+ weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
+ else:
+ weight = weights.get_weights_row(prefix)
+
+ if bias and weights.process_group.rank() == 0:
+ # Only rank 0 loads the bias so it is added exactly once across shards
+ bias = weights.get_tensor(f"{prefix}.bias")
+ else:
+ bias = None
+ return TensorParallelRowLinear(
+ get_linear(weight, bias), process_group=weights.process_group
+ )
+
+
+class FlashMQAttention(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ num_heads = config.num_attention_heads
+ hidden_size = config.hidden_size
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_size = hidden_size // num_heads
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+
+ self.softmax_scale = self.head_size ** (-0.5)
+
+ self.c_attn = load_multi_mqa(
+ config,
+ prefix=prefix,
+ weights=weights,
+ bias=True,
+ head_size=self.head_size,
+ hidden_size=hidden_size,
+ num_heads=self.num_heads,
+ )
+ self.c_proj = load_row(
+ config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
+ )
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+ self.kv_head_mapping = torch.zeros(
+ self.num_heads, dtype=torch.int32, device=weights.device
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ qkv = self.c_attn(hidden_states)
+
+ # Split query from key_value
+ query, key_value = qkv.split(
+ [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+ )
+
+ # Prepare query and key_value for indexing
+ query = query.view(-1, self.num_heads, self.head_size)
+ key_value = key_value.view(-1, 2, 1, self.head_size)
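+ # Multi-query attention: a single KV head (third dim) is shared by all query heads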
+
+ kv_cache.store(
+ key=key_value[:, 0],
+ value=key_value[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=key_value[:, 0],
+ value=key_value[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class MLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.activation_function
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+
+ self.c_fc = load_col(
+ config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
+ )
+ self.c_proj = load_row(
+ config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ return hidden_states
+
+
+class Block(nn.Module):
+ def __init__(self, prefix: str, layer_id, config, weights):
+ super().__init__()
+ prefix = f"{prefix}.h.{layer_id}"
+ self.ln_1 = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
+ )
+ self.ln_2 = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
+ )
+ self.self_attn = FlashMQAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ )
+ self.mlp = MLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ ):
+ hidden_states, residual = self.ln_1(hidden_states, residual)
+ hidden_states = self.self_attn(
+ hidden_states,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, residual = self.ln_2(hidden_states, residual)
+
+ mlp_output = self.mlp(hidden_states)
+
+ return mlp_output, residual
+
+
+class FlashSantacoderModel(nn.Module):
+ def __init__(self, prefix: str, config, weights):
+ super().__init__()
+ self.config = config
+
+ self.process_group = weights.process_group
+ self.wte = TensorParallelEmbedding(
+ prefix=f"{prefix}.wte",
+ weights=weights,
+ reduce=False,
+ )
+ self.wpe = TensorParallelEmbedding(
+ prefix=f"{prefix}.wpe",
+ weights=weights,
+ reduce=False,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Block(
+ prefix,
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.ln_f = FastLayerNorm.load(
+ prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+
+ if self.process_group.size() > 1:
+ torch.distributed.all_reduce(hidden_states, group=self.process_group)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.ln_f(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashSantacoderForCausalLM(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "transformer"
+ else:
+ prefix = f"{prefix}.transformer"
+
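+ # GPT2-style checkpoints store transposed (Conv1D) weights, hence the transpose flag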
+ config.transpose = config.architectures[0].startswith("GPT2")
+ self.model = FlashSantacoderModel(prefix, config, weights)
+ self.lm_head = SpeculativeHead.load(
+ config, prefix=f"{prefix}.wte", weights=weights
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
new file mode 100644
index 000000000..76f6f473a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -0,0 +1,595 @@
+# coding=utf-8
+# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.layers.attention import (
+ paged_attention,
+ attention,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.layers import (
+ TensorParallelMultiAdapterLinear,
+ TensorParallelAdapterRowLinear,
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+ get_linear,
+)
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
+from text_generation_server.layers.layernorm import (
+ FastLayerNorm,
+ FastRMSNorm,
+)
+from text_generation_server.layers.rotary import (
+ PositionRotaryEmbedding,
+)
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class Starcoder2Config(PretrainedConfig):
+ model_type = "starcoder2"
+
+ def __init__(
+ self,
+ vocab_size=49152,
+ hidden_size=3072,
+ intermediate_size=12288,
+ num_hidden_layers=30,
+ num_attention_heads=24,
+ num_key_value_heads=2,
+ mlp_type="default",
+ hidden_act="gelu_pytorch_tanh",
+ max_position_embeddings=4096,
+ initializer_range=0.018042,
+ norm_type="layer_norm",
+ norm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=50256,
+ eos_token_id=50256,
+ rope_theta=10000.0,
+ sliding_window=None,
+ attention_dropout=0.0,
+ residual_dropout=0.0,
+ embedding_dropout=0.0,
+ use_bias: bool = True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.sliding_window = sliding_window
+ self.use_bias = use_bias
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.mlp_type = mlp_type
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.norm_type = norm_type
+ self.norm_epsilon = norm_epsilon
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.residual_dropout = residual_dropout
+ self.embedding_dropout = embedding_dropout
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+
+def load_attention(config, prefix, weights, layer_id):
+ prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+ head_size = config.hidden_size // config.num_attention_heads
+ sizes = [
+ head_size * config.num_attention_heads,
+ head_size * config.num_key_value_heads,
+ head_size * config.num_key_value_heads,
+ ]
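+ # GQA needs a custom per-shard q/k/v concat; plain MHA can use the standard fused column load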
+ if config.num_attention_heads != config.num_key_value_heads:
+ base_layer = _load_gqa(config, prefix, weights)
+ else:
+ base_layer = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=prefixes,
+ dim=0,
+ weights=weights,
+ bias=config.use_bias,
+ )
+ return TensorParallelMultiAdapterLinear.load(
+ base_layer=base_layer,
+ layer_id=layer_id,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ )
+
+ if isinstance(weight, UnquantizedWeight):
+ weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+ if config.use_bias:
+ w = [
+ weights.get_sharded(f"{p}.bias", dim=0)
+ for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+ ]
+ bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+ else:
+ bias = None
+
+ return TensorParallelColumnLinear(get_linear(weight, bias=bias))
+
+
+class Starcoder2Attention(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.max_past = (
+ config.sliding_window if config.sliding_window is not None else -1
+ )
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights, index)
+ self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+ o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=getattr(config, "use_bias", False),
+ )
+
+ self.o_proj = TensorParallelAdapterRowLinear.load(
+ o_proj,
+ index,
+ "o_proj",
+ process_group=weights.process_group,
+ )
+
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ qkv = self.query_key_value(hidden_states, adapter_data)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+ kv_cache.store(
+ key=kv[:, 0],
+ value=kv[:, 1],
+ slots=slots,
+ kv_scales=self.kv_scales,
+ )
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # sdpa
+ attn_output = attention(
+ query=query,
+ key=kv[:, 0],
+ value=kv[:, 1],
+ kv_cache=kv_cache,
+ kv_scales=self.kv_scales,
+ seqlen=seqlen,
+ softmax_scale=self.softmax_scale,
+ window_size_left=self.max_past,
+ )
+ # Decode
+ else:
+ attn_output = paged_attention(
+ query,
+ kv_cache,
+ self.kv_head_mapping,
+ self.softmax_scale,
+ seqlen,
+ kv_scales=self.kv_scales,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ return self.o_proj(
+ attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+ )
+
+
+class Starcoder2MLP(nn.Module):
+ def __init__(self, prefix, config, weights, index):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Standard (non-gated) MLP: c_fc -> activation -> c_proj
+ c_fc = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.c_fc",
+ weights=weights,
+ bias=config.use_bias,
+ )
+ c_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=config.use_bias,
+ )
+
+ self.c_fc = TensorParallelMultiAdapterLinear.load(
+ c_fc,
+ layer_id=index,
+ layer_names=[f"{prefix}.c_fc"],
+ sizes=[config.intermediate_size, config.intermediate_size],
+ process_group=weights.process_group,
+ )
+
+ self.c_proj = TensorParallelAdapterRowLinear.load(
+ c_proj,
+ index,
+ "c_proj",
+ process_group=weights.process_group,
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ hidden_states = self.c_fc(hidden_states, adapter_data)
+ hidden_states = self.act(hidden_states)
+ return self.c_proj(hidden_states, adapter_data)
+
+
+class Starcoder2GatedMLP(nn.Module):
+ def __init__(self, index, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ # Fuse gate and up proj
+ prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
+ sizes = [
+ config.intermediate_size,
+ config.intermediate_size,
+ ]
+ gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=prefixes,
+ weights=weights,
+ dim=0,
+ bias=config.use_bias,
+ )
+ self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+ gate_up_proj,
+ index,
+ layer_names=prefixes,
+ sizes=sizes,
+ process_group=weights.process_group,
+ )
+ down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=config.use_bias,
+ )
+ self.down_proj = TensorParallelAdapterRowLinear.load(
+ down_proj,
+ index,
+ "down_proj",
+ process_group=weights.process_group,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states, adapter_data):
+ gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
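+ # Gated MLP: apply the activation to the gate half and scale the up half elementwise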
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+ )
+
+
+STARCODER2_NORMALIZATION_CLASSES = {
+ "layer_norm": FastLayerNorm,
+ "rms_norm": FastRMSNorm,
+}
+
+STARCODER2_MLP_CLASSES = {
+ "default": Starcoder2MLP,
+ "gated": Starcoder2GatedMLP,
+}
+
+
+class Starcoder2Layer(nn.Module):
+ def __init__(self, layer_id, config, weights):
+ super().__init__()
+ prefix = f"model.layers.{layer_id}"
+ self.self_attn = Starcoder2Attention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
+ )
+
+ self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
+ prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
+ )
+
+ self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
+ )
+ self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
+ config.norm_type
+ ].load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.norm_epsilon,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ # Post-attention normalization, fused with the residual add
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+
+ return mlp_output, attn_res
+
+
+class Starcoder2Model(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix=f"{prefix}.embed_tokens", weights=weights
+ )
+ self.layers = nn.ModuleList(
+ [
+ Starcoder2Layer(
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ adapter_data,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+ # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
+
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashStarcoder2ForCausalLM(torch.nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ if not prefix:
+ prefix = "model"
+ else:
+ prefix = f"{prefix}.model"
+
+ self.model = Starcoder2Model(prefix, config, weights)
+ try:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix="lm_head",
+ weights=weights,
+ )
+ except RuntimeError:
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=f"{prefix}.embed_tokens",
+ weights=weights,
+ )
+
+ self.max_past = config.sliding_window
+ self.max_past_tensor = (
+ torch.tensor(config.sliding_window, device=weights.device)
+ if self.max_past is not None
+ else None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ slots,
+ seqlen,
+ adapter_data,
+ hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
new file mode 100644
index 000000000..02806ac94
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -0,0 +1,852 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Idefics2 model."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import math
+
+from transformers.activations import ACT2FN
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+)
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+)
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Idefics2VisionEmbeddings(nn.Module):
+ """
+ This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
+ which allows treating images in their native aspect ratio, without resizing them to the same
+ fixed size. In particular, we start from the original pre-trained SigLIP model
+ (which was trained on fixed-size square images) and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+ self.patch_embedding.bias = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
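+ # Bucketize each image's fractional patch coordinates onto the fixed
+ # num_patches_per_side grid so variable-resolution images reuse the
+ # pre-trained position embeddings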
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class Idefics2VisionAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_size = self.embed_dim // self.num_heads
+ if self.head_size * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_size**-0.5
+ self.dropout = config.attention_dropout
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = (
+ torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+ )
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Idefics2VisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics2EncoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics2VisionAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
+ )
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
+ )
+ self.mlp = Idefics2VisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class Idefics2Encoder(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Idefics2EncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+ return hidden_states
+
+
+class Idefics2VisionTransformer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embeddings = Idefics2VisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.encoder = Idefics2Encoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ self.post_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ ):
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.config.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(
+ dtype=torch.bool, device=pixel_values.device
+ )
+
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+ )
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+ # so when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+ # we avoid passing the attention_mask, which is equivalent to attending to the full sequence
+ if not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+ else:
+ patch_attention_mask = _prepare_4d_attention_mask(
+ patch_attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return last_hidden_state
+
+
+class Idefics2MLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.text_config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate=(
+ "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+ ),
+ )
+ )
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ def forward(self, hidden_states):
+ start_shape = hidden_states.shape[:-1]
+ gate_up_states = self.gate_up_proj(hidden_states)
+ intermediate_size = gate_up_states.shape[-1] // 2
+ gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
+ return self.down_proj(
+ self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
+ ).view(*start_shape, -1)
+
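The fused `gate_up_proj` output above is split back into its gate and up halves with a single `view`. A small sketch of that layout trick with made-up shapes (not the model's real sizes; `silu` stands in for whatever `self.act` resolves to):

```python
import torch

batch, seq, intermediate_size = 2, 5, 8
gate = torch.randn(batch, seq, intermediate_size)
up = torch.randn(batch, seq, intermediate_size)
fused = torch.cat([gate, up], dim=-1)          # what the fused projection emits

split = fused.view(-1, 2, intermediate_size)   # flatten batch/seq, separate halves
assert torch.equal(split[:, 0], gate.reshape(-1, intermediate_size))
assert torch.equal(split[:, 1], up.reshape(-1, intermediate_size))
gated = torch.nn.functional.silu(split[:, 0]) * split[:, 1]  # gated activation
```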
+
+class Idefics2RMSNorm(nn.Module):
+ def __init__(self, prefix, weights, eps):
+ """
+ Idefics2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.weight"), requires_grad=False
+ )
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
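For reference, the forward above is the standard RMSNorm computation (done in fp32 and cast back to the input dtype):

```latex
\mathrm{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \varepsilon}}
```

where d is the hidden size, w the learned weight and epsilon the `variance_epsilon`.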
+
+class Idefics2PerceiverAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+
+ self.layer_idx = None
+ self.hidden_size = config.text_config.hidden_size
+ self.num_heads = config.perceiver_config.resampler_n_heads
+ self.head_size = config.perceiver_config.resampler_head_dim
+ self.num_key_value_heads = config.perceiver_config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.attention_dropout = config.perceiver_config.attention_dropout
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ self.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.q_proj = TensorParallelColumnLinear.load(
+ config,
+ prefix=f"{prefix}.q_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.kv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
+ )
+
+ self.is_causal = False
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = latents.size()
+ kv_seq_len = q_len + context.size()[1]
+
+ hidden_states = torch.concat([context, latents], dim=-2)
+ query_states = self.q_proj(latents)
+ kv = self.kv(hidden_states)
+ key_states, value_states = kv.split(
+ [
+ self.head_size * self.num_key_value_heads,
+ self.head_size * self.num_key_value_heads,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, kv_seq_len, self.num_key_value_heads, self.head_size
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, kv_seq_len, self.num_key_value_heads, self.head_size
+ ).transpose(1, 2)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_size)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output
+
+
+class Idefics2PerceiverLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.text_config.hidden_size
+ self.n_latents = config.perceiver_config.resampler_n_latents
+ self.depth = config.perceiver_config.resampler_depth
+ self.rms_norm_eps = config.text_config.rms_norm_eps
+
+ self.input_latents_norm = Idefics2RMSNorm(
+ prefix=f"{prefix}.input_latents_norm",
+ weights=weights,
+ eps=self.rms_norm_eps,
+ )
+ self.input_context_norm = Idefics2RMSNorm(
+ prefix=f"{prefix}.input_context_norm",
+ weights=weights,
+ eps=self.rms_norm_eps,
+ )
+ self.self_attn = Idefics2PerceiverAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.post_attention_layernorm = Idefics2RMSNorm(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=self.rms_norm_eps,
+ )
+ self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ """
+ residual = latents
+
+ latents = self.input_latents_norm(latents)
+ context = self.input_context_norm(context)
+
+ latents = self.self_attn(
+ latents=latents,
+ context=context,
+ attention_mask=attention_mask,
+ )
+ latents = residual + latents
+ residual = latents
+
+ latents = self.post_attention_layernorm(latents)
+ latents = self.mlp(latents)
+ latents = residual + latents
+
+ return latents
+
+
+class Idefics2PerceiverResampler(nn.Module):
+ def __init__(self, prefix, config, weights) -> None:
+ super().__init__()
+ self.hidden_size = config.text_config.hidden_size
+ self.hidden_act = config.perceiver_config.hidden_act
+ self.n_latents = config.perceiver_config.resampler_n_latents
+ self.depth = config.perceiver_config.resampler_depth
+ self.rms_norm_eps = config.text_config.rms_norm_eps
+
+ # Create Latents for Perceiver
+ self.latents = weights.get_tensor(f"{prefix}.latents")
+
+ # Create Transformer Blocks
+ self.layers = nn.ModuleList(
+ [
+ Idefics2PerceiverLayer(
+ prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
+ )
+ for idx in range(self.depth)
+ ]
+ )
+ self.norm = Idefics2RMSNorm(
+ prefix=f"{prefix}.norm",
+ weights=weights,
+ eps=config.text_config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ context: torch.Tensor,
+ attention_mask,
+ ) -> torch.Tensor:
+ # seq embed -> bsz seq embed
+ latents = self.latents.unsqueeze(0).expand(
+ (context.shape[0], *self.latents.size())
+ )
+
+ latent_attention_mask = torch.ones(
+ (attention_mask.size(0), latents.size(1)),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+ attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
+ attention_mask = _prepare_4d_attention_mask(
+ attention_mask, latents.dtype, tgt_len=self.n_latents
+ )
+
+ compressed_context = latents
+ for perceiver_layer in self.layers:
+ compressed_context = perceiver_layer(
+ compressed_context,
+ context,
+ attention_mask=attention_mask,
+ )
+ compressed_context = self.norm(compressed_context)
+
+ return compressed_context
+
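Because the learned latents act as a fixed-length set of queries, the resampler above always returns `resampler_n_latents` tokens per image, however long the incoming image sequence is. A toy shape walkthrough (hypothetical sizes, not the real config values):

```python
import torch

bsz, seq, hidden, n_latents = 1, 32, 16, 8
context = torch.randn(bsz, seq, hidden)        # projected image features
latents = torch.randn(n_latents, hidden)       # learned query vectors

latents = latents.unsqueeze(0).expand(bsz, n_latents, hidden)
# Keys/values are built from [context, latents] concatenated on the sequence
# axis, so kv_seq_len == seq + n_latents while the query length stays n_latents.
kv_input = torch.cat([context, latents], dim=-2)
print(latents.shape, kv_input.shape)  # torch.Size([1, 8, 16]) torch.Size([1, 40, 16])
```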
+
+class Idefics2Connector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.modality_projection = Idefics2MLP(
+ prefix=f"{prefix}.modality_projection", config=config, weights=weights
+ )
+ self.perceiver_resampler = Idefics2PerceiverResampler(
+ prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
+ )
+
+ def forward(self, image_hidden_states, attention_mask):
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ image_hidden_states = self.perceiver_resampler(
+ context=image_hidden_states, attention_mask=attention_mask
+ )
+ return image_hidden_states
+
+
+class Idefics2ForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+
+ vision_config = config.vision_config
+ self.text_model = load_text_model(
+ prefix="model" if not prefix else f"{prefix}.model",
+ config=config.text_config,
+ weights=weights,
+ name="text_model",
+ )
+ self.dtype = weights.dtype
+
+ # The vision and connector models are not quantized.
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
+ self.vision_model = Idefics2VisionTransformer(
+ prefix=(
+ f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+ ),
+ config=vision_config,
+ weights=weights,
+ )
+
+ config.quantize = None
+ self.connector = Idefics2Connector(
+ prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+ config=config,
+ weights=weights,
+ )
+
+ self.config = config
+ self.image_seq_len = config.perceiver_config.resampler_n_latents
+ self.image_token_id = config.image_token_id
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def _merge_input_ids_with_image_features(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_features: torch.Tensor,
+ ):
+ """In place merges in vision_embeddings with inputs_embeds."""
+ # mask = input_ids == self.config.image_token_index
+ # - use torch.where instead of a plain `==` mask to avoid an issue with HPU graphs
+ mask = torch.where(input_ids == self.config.image_token_id)
+ # Let's pray we have enabled enough slots!
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ # Unused here
+ image_sizes: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+ if pixel_values is not None:
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ all_states = []
+ all_pixel_values = pixel_values
+ all_pixel_mask = pixel_attention_mask
+ for i in range(batch_size):
+ pixel_values = all_pixel_values.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ pixel_values = pixel_values[i : i + 1]
+ pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(
+ dim=(-1, -2, -3)
+ ) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(
+ pixel_values.size(0),
+ pixel_values.size(2),
+ pixel_values.size(3),
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask
+ pixel_attention_mask = all_pixel_mask[i : i + 1]
+ pixel_attention_mask = pixel_attention_mask.view(
+ 1 * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[
+ real_images_inds
+ ].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ """
+ patches_subgrid = pixel_attention_mask.unfold(
+ dimension=1, size=patch_size, step=patch_size
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2, size=patch_size, step=patch_size
+ )
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+ """
+ # hpu does not support unfold, so use a conv2d with a ones kernel instead
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size],
+ dtype=pixel_values.dtype,
+ device=pixel_values.device,
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+ conv_kernel,
+ stride=patch_size,
+ ).squeeze(1)
+ patch_attention_mask = torch.eq(
+ patches_subgrid, (patch_size * patch_size)
+ )
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ )
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states,
+ attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
+ )
+ all_states.append(image_hidden_states)
+ image_hidden_states = torch.stack(all_states, dim=0)
+ # When generating, we don't want to replace an image_token_id that the model itself produced
+ # with image features for images that simply don't exist
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, image_hidden_states
+ )
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+ return logits, speculative_logits
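One detail worth calling out in the forward above: since `unfold` is unavailable on Gaudi, the per-patch validity mask is built by counting valid pixels with a ones-kernel conv2d and keeping only fully covered patches (the commented-out unfold variant kept patches with any valid pixel). A standalone sanity sketch on a toy 4x4 mask with patch_size=2, not part of the model code:

```python
import torch

patch_size = 2
# Toy (1, H, W) boolean-style mask: the bottom-right 2x2 patch is padding.
mask = torch.tensor([[[1, 1, 1, 1],
                      [1, 1, 1, 1],
                      [1, 1, 0, 0],
                      [1, 1, 0, 0]]], dtype=torch.float32)

kernel = torch.ones(1, 1, patch_size, patch_size)
# Each output element is the number of valid pixels inside one patch.
counts = torch.nn.functional.conv2d(mask.unsqueeze(1), kernel, stride=patch_size).squeeze(1)
patch_attention_mask = counts == patch_size * patch_size
print(patch_attention_mask)
# tensor([[[ True,  True],
#          [ True, False]]])
```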
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
new file mode 100644
index 000000000..964526fcf
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -0,0 +1,596 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Idefics3 model."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from text_generation_server.models.custom_modeling.vlm import (
+ load_text_model,
+)
+from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+)
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
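The docstring of `repeat_kv` above claims equivalence with `torch.repeat_interleave`; a quick self-contained check (the function body is restated so the snippet runs on its own):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the function above, restated for a standalone check.
    batch, kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, kv_heads * n_rep, slen, head_dim)


x = torch.randn(1, 2, 5, 4)  # (batch, num_key_value_heads, seqlen, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
```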
+
+class Idefics3VisionEmbeddings(nn.Module):
+ """
+ This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
+ resolution.
+
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
+ fixed size. In particular, we start from the original pre-trained SigLIP model
+ (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
+ """
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+ self.patch_embedding.bias = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
+ )
+
+ self.num_patches_per_side = self.image_size // self.patch_size
+ self.num_patches = self.num_patches_per_side**2
+ self.num_positions = self.num_patches
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+
+ def forward(
+ self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
+ ) -> torch.Tensor:
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = (
+ max_im_h // self.patch_size,
+ max_im_w // self.patch_size,
+ )
+ boundaries = torch.arange(
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+ )
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = p_attn_mask[:, 0].sum()
+ nb_patches_w = p_attn_mask[0].sum()
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(
+ fractional_coords_h, boundaries, right=True
+ )
+ bucket_coords_w = torch.bucketize(
+ fractional_coords_w, boundaries, right=True
+ )
+
+ pos_ids = (
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
+ ).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+ position_ids = position_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class Idefics3VisionAttention(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_size = self.embed_dim // self.num_heads
+ if self.head_size * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_size**-0.5
+ self.dropout = config.attention_dropout
+
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=True,
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ self.head_size * self.num_heads,
+ ],
+ dim=2,
+ )
+
+ query_states = query_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ batch_size, q_len, self.num_heads, self.head_size
+ ).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = (
+ torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+ )
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class Idefics3VisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Idefics3EncoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = Idefics3VisionAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
+ )
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
+ )
+ self.mlp = Idefics3VisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class Idefics3Encoder(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ Idefics3EncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ hidden_states = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+ return hidden_states
+
+
+class Idefics3VisionTransformer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embeddings = Idefics3VisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.encoder = Idefics3Encoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+ self.post_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ def forward(
+ self,
+ pixel_values,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ ):
+ batch_size = pixel_values.size(0)
+ if patch_attention_mask is None:
+ patch_size = self.config.patch_size
+ patch_attention_mask = torch.ones(
+ (
+ batch_size,
+ pixel_values.size(2) // patch_size,
+ pixel_values.size(3) // patch_size,
+ )
+ )
+ patch_attention_mask = patch_attention_mask.to(
+ dtype=torch.bool, device=pixel_values.device
+ )
+
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
+ )
+
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
+ # so when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+ # we avoid passing the attention_mask, which is equivalent to attending to the full sequence
+ if not torch.any(~patch_attention_mask):
+ patch_attention_mask = None
+ else:
+ patch_attention_mask = _prepare_4d_attention_mask(
+ patch_attention_mask, hidden_states.dtype
+ )
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=patch_attention_mask,
+ )
+
+ last_hidden_state = encoder_outputs
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ return last_hidden_state
+
+
+class Idefics3SimpleMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+ output_size = config.text_config.hidden_size
+ proj = nn.Parameter(
+ weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
+ requires_grad=False,
+ ).to(weights.dtype)
+ self.proj = nn.Linear(input_size, output_size, bias=False)
+ self.proj.weight = proj
+
+ def forward(self, x):
+ return self.proj(x)
+
+
+class Idefics3Connector(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
+ self.scale_factor = config.scale_factor
+
+ def pixel_shuffle(self, x, scale_factor=2):
+ bsz, seq, embed_dim = x.size()
+ height = width = int(seq**0.5)
+ x = x.view(bsz, height, width, embed_dim)
+ x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(
+ bsz,
+ int(width / scale_factor),
+ int(height / scale_factor),
+ embed_dim * (scale_factor**2),
+ )
+ x = x.permute(0, 2, 1, 3)
+ x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+ return x
+
+ def forward(self, image_hidden_states):
+ image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+ image_hidden_states = self.modality_projection(image_hidden_states)
+ return image_hidden_states
+
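To make the reshuffling in `pixel_shuffle` above easier to follow: with a scale factor of 2, every 2x2 block of patch tokens is folded into a single token whose embedding is 4x wider. A toy shape walkthrough (hypothetical sizes):

```python
import torch

bsz, height, width, dim, sf = 1, 4, 4, 8, 2
x = torch.randn(bsz, height * width, dim)            # (1, 16, 8)

x = x.view(bsz, height, width, dim)
x = x.view(bsz, height, width // sf, dim * sf)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, width // sf, height // sf, dim * sf * sf)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, (height * width) // (sf * sf), dim * sf * sf)
print(x.shape)  # torch.Size([1, 4, 32]): 16 tokens become 4, each 4x wider
```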
+
+class Idefics3ForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ config.text_config.quantize = config.quantize
+ config.text_config.speculator = config.speculator
+ # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
+ # since Idefics3 uses the `embed_tokens` for the final prediction
+ # config.text_config.tie_word_embeddings = True
+
+ vision_config = config.vision_config
+ self.text_model = load_text_model(
+ prefix="model" if not prefix else f"{prefix}.model",
+ config=config.text_config,
+ weights=weights,
+ name="text_model",
+ )
+ self.dtype = weights.dtype
+
+ # The vision and connector models are not quantized.
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
+ self.vision_model = Idefics3VisionTransformer(
+ prefix=(
+ f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+ ),
+ config=vision_config,
+ weights=weights,
+ )
+
+ config.quantize = None
+ self.connector = Idefics3Connector(
+ prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+ config=config,
+ weights=weights,
+ )
+
+ self.config = config
+ self.image_token_id = config.image_token_id
+ self.pad_token_id = (
+ config.pad_token_id if config.pad_token_id is not None else -1
+ )
+
+ def _merge_input_ids_with_image_features(
+ self,
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ image_features: torch.Tensor,
+ ):
+ """In place merges in vision_embeddings with inputs_embeds."""
+ # mask = input_ids == self.config.image_token_index
+ # - use torch.where instead of a plain `==` mask to avoid an issue with HPU graphs
+ mask = torch.where(input_ids == self.config.image_token_id)
+ # Let's pray we have enabled enough slots!
+ inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor] = None,
+ pixel_values: torch.FloatTensor = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ # Unused here
+ image_sizes: Optional[torch.Tensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+ inputs_embeds = self.text_model.embed_tokens(input_ids)
+ if pixel_values is not None:
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ all_states = []
+ all_pixel_values = pixel_values
+ all_pixel_mask = pixel_attention_mask
+ for i in range(batch_size):
+ pixel_values = all_pixel_values.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ pixel_values = pixel_values[i : i + 1]
+ pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(
+ dim=(-1, -2, -3)
+ ) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(
+ pixel_values.size(0),
+ pixel_values.size(2),
+ pixel_values.size(3),
+ ),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+ # Remove padding images from the mask
+ pixel_attention_mask = all_pixel_mask[i : i + 1]
+ pixel_attention_mask = pixel_attention_mask.view(
+ 1 * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[
+ real_images_inds
+ ].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ """
+ patches_subgrid = pixel_attention_mask.unfold(
+ dimension=1, size=patch_size, step=patch_size
+ )
+ patches_subgrid = patches_subgrid.unfold(
+ dimension=2, size=patch_size, step=patch_size
+ )
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+ """
+ # hpu does not support unfold, so use a conv2d with a ones kernel instead
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size],
+ dtype=pixel_values.dtype,
+ device=pixel_values.device,
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
+ conv_kernel,
+ stride=patch_size,
+ ).squeeze(1)
+ patch_attention_mask = torch.eq(
+ patches_subgrid, (patch_size * patch_size)
+ )
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ )
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states,
+ )
+
+ all_states.append(image_hidden_states)
+ image_hidden_states = torch.stack(all_states, dim=0)
+
+ inputs_embeds = self._merge_input_ids_with_image_features(
+ input_ids, inputs_embeds, image_hidden_states
+ )
+
+ hidden_states = self.text_model.model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ adapter_data=adapter_data,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.text_model.lm_head(hidden_states)
+ return logits, speculative_logits
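A toy sketch of the image-token merge performed by `_merge_input_ids_with_image_features` above; the image token id below is a made-up value, purely for illustration:

```python
import torch

image_token_id = 32001                     # hypothetical id, for illustration only
input_ids = torch.tensor([5, image_token_id, 7, image_token_id])
inputs_embeds = torch.zeros(4, 3)          # text embeddings (toy width 3)
image_features = torch.arange(6, dtype=torch.float32).view(2, 3)  # 2 image tokens

mask = torch.where(input_ids == image_token_id)   # tuple of index tensors
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
print(inputs_embeds)  # rows 1 and 3 now hold the image features
```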
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py
new file mode 100644
index 000000000..a55658194
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_config.py
@@ -0,0 +1,326 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Idefics model configuration"""
+import copy
+
+from transformers import PretrainedConfig
+
+IDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "HuggingFaceM4/idefics-9b": "https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json",
+ "HuggingFaceM4/idefics-80b": "https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json",
+}
+
+
+class IdeficsVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
+ Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Idefics-9B.
+ e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`)
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ intermediate_size (`int`, *optional*, defaults to 5120):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_num_channels (`int`, *optional*, defaults to `3`):
+ Number of image channels.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1.0):
+ A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
+ testing).
+ """
+
+ model_type = "idefics"
+ attribute_map = {
+ "hidden_size": "embed_dim",
+ }
+
+ def __init__(
+ self,
+ embed_dim=768,
+ image_size=224,
+ intermediate_size=5120,
+ patch_size=14,
+ num_hidden_layers=32,
+ num_attention_heads=16,
+ num_channels=3,
+ hidden_act="gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ **kwargs,
+ ):
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.intermediate_size = intermediate_size
+ self.patch_size = patch_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.hidden_act = hidden_act
+
+ super().__init__(**kwargs)
+
+
+class IdeficsPerceiverConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
+ Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Idefics-9B.
+ e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ use_resampler (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the resampler
+ resampler_n_latents (`int`, *optional*, defaults to 64):
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+ resampler_depth (`int`, *optional*, defaults to 6):
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
+ resampler_n_heads (`int`, *optional*, defaults to 16):
+ Number of heads in each Transformer block (for multi-headed self-attention).
+ resampler_head_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of each head projection in the Transformer block.
+ qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
+ Whether or not to use qk layer norms in perceiver
+ """
+
+ model_type = "idefics"
+
+ def __init__(
+ self,
+ use_resampler=False,
+ resampler_n_latents=64,
+ resampler_depth=6,
+ resampler_n_heads=16,
+ resampler_head_dim=96,
+ qk_layer_norms_perceiver=False,
+ **kwargs,
+ ):
+ self.use_resampler = use_resampler
+ self.resampler_n_latents = resampler_n_latents
+ self.resampler_depth = resampler_depth
+ self.resampler_n_heads = resampler_n_heads
+ self.resampler_head_dim = resampler_head_dim
+ self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
+
+ super().__init__(**kwargs)
+
+
+class IdeficsConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an
+ Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Idefics-9B.
+ e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ additional_vocab_size (`int`, *optional*, defaults to 0):
+ Additional vocabulary size of the model, typically for the special "<image>" token. Additional vocab tokens
+ are always trainable whereas regular vocab tokens can be frozen or not.
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~IdeficsModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
+ Initialization type for the alphas.
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
+ Attention.
+ alpha_type (`str`, *optional*, defaults to `"float"`):
+ Whether the gating alphas should be vectors or single floats.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ cross_layer_interval (`int`, *optional*, defaults to 1):
+ Interval for cross attention (from text to image) layers.
+ qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k
+ freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers
+ freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`):
+ Exceptions to freezing text layers when `freeze_text_layers` is `True`
+ freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head
+ freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers
+ freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`):
+ Exceptions to freezing vision layers when `freeze_vision_layers` is `True`
+ use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler
+ vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict
+ perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict
+ Example:
+ ```python
+ >>> from transformers import IdeficsModel, IdeficsConfig
+ >>> # Initializing a Idefics idefics-9b style configuration
+ >>> configuration = IdeficsConfig()
+ >>> # Initializing a model from the idefics-9b style configuration
+ >>> model = IdeficsModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "idefics"
+ is_composition = True
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ additional_vocab_size=0,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ dropout=0.0,
+ hidden_act="silu",
+ initializer_range=0.02,
+ alpha_initializer="zeros",
+ alphas_initializer_range=0.0,
+ alpha_type="float",
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ cross_layer_interval=1,
+ qk_layer_norms=False,
+ freeze_text_layers=True,
+ freeze_text_module_exceptions=[],
+ freeze_lm_head=False,
+ freeze_vision_layers=True,
+ freeze_vision_module_exceptions=[],
+ use_resampler=False,
+ vision_config=None,
+ perceiver_config=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.additional_vocab_size = additional_vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.dropout = dropout
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.alpha_initializer = alpha_initializer
+ self.alphas_initializer_range = alphas_initializer_range
+ self.alpha_type = alpha_type
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+
+ self.cross_layer_interval = cross_layer_interval
+ self.qk_layer_norms = qk_layer_norms
+ self.freeze_vision_layers = freeze_vision_layers
+
+ self.freeze_text_layers = freeze_text_layers
+ self.freeze_text_module_exceptions = freeze_text_module_exceptions
+ self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
+ self.freeze_lm_head = freeze_lm_head
+
+ self.use_resampler = use_resampler
+
+ if perceiver_config is None:
+ self.perceiver_config = IdeficsPerceiverConfig()
+ elif isinstance(perceiver_config, dict):
+ self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config)
+ elif isinstance(perceiver_config, IdeficsPerceiverConfig):
+ self.perceiver_config = perceiver_config
+
+ if vision_config is None:
+ self.vision_config = IdeficsVisionConfig()
+ elif isinstance(vision_config, dict):
+ self.vision_config = IdeficsVisionConfig(**vision_config)
+ elif isinstance(vision_config, IdeficsVisionConfig):
+ self.vision_config = vision_config
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
+ # of this object many attributes have default values and haven't yet been overridden.
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+
+ output["vision_config"] = self.vision_config.to_dict()
+ output["perceiver_config"] = self.perceiver_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+
+ return output
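Assuming the module path shown in this diff, a small usage sketch of the composition logic above: nested configs may be passed as plain dicts and come back out through the `to_dict` override (the values are arbitrary examples, not a real checkpoint's config):

```python
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig

config = IdeficsConfig(
    vision_config={"embed_dim": 1280, "patch_size": 14},
    perceiver_config={"use_resampler": True, "resampler_n_latents": 64},
)
d = config.to_dict()
assert d["vision_config"]["embed_dim"] == 1280
assert d["perceiver_config"]["resampler_n_latents"] == 64
assert d["model_type"] == "idefics"
```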
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
new file mode 100644
index 000000000..afb8e1f9c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_image_processing.py
@@ -0,0 +1,297 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Idefics."""
+
+from typing import Callable, Dict, List, Optional, Union, Iterable
+import numpy as np
+
+from PIL import Image
+
+import transformers
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import (
+ resize,
+ to_channel_dimension_format,
+ rescale,
+ normalize,
+)
+from transformers.image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from io import BytesIO
+import base64
+import requests
+from transformers import TensorType, is_torch_available
+
+
+IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]
+IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]
+
+
+def convert_to_rgb(image):
+    # `image.convert("RGB")` alone only works for .jpg images: for images with an alpha
+    # channel it produces a wrong background. Compositing onto a white background via
+    # `alpha_composite` handles that case.
+ if image.mode == "RGB":
+ return image
+
+ image_rgba = image.convert("RGBA")
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+ alpha_composite = Image.alpha_composite(background, image_rgba)
+ alpha_composite = alpha_composite.convert("RGB")
+ return alpha_composite
+
+
+class IdeficsImageProcessor(BaseImageProcessor):
+ r"""
+    Constructs an Idefics image processor.
+    Args:
+        image_size (`int`, *optional*, defaults to `224`):
+            Size to resize images to; images are resized to `(image_size, image_size)`.
+        image_num_channels (`int`, *optional*, defaults to `3`):
+            Number of image channels.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
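+    # Example usage (illustrative sketch only; the URL below is a placeholder):
+    #   processor = IdeficsImageProcessor()
+    #   image = processor.fetch_images("https://example.com/cat.png")
+    #   pixel_values = processor.preprocess([image])  # torch.FloatTensor of shape [1, 3, 224, 224]
+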
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ image_num_channels: Optional[int] = 3,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.image_num_channels = image_num_channels
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ image_num_channels: Optional[int] = 3,
+ image_size: Optional[Dict[str, int]] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+        transform: Optional[Callable] = None,
+ **kwargs,
+ ) -> TensorType.PYTORCH:
+ """
+ Preprocess a batch of images.
+ Args:
+ images (`ImageInput`):
+ A list of images to preprocess.
+            image_size (`int`, *optional*, defaults to `self.image_size`):
+                Size to resize images to; images are resized to `(image_size, image_size)`.
+            image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):
+                Number of image channels.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
+                Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+                channels in the image.
+            image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
+                Standard deviation to use if normalizing the image. This is a float or list of floats the length of
+                the number of channels in the image.
+            transform (`Callable`, *optional*, defaults to `None`):
+                A custom transform function that accepts a single image, which can be passed in for training. For
+                example, `torchvision.transforms.Compose` can be used to compose multiple transforms. If `None`,
+                inference mode is assumed and a preset of inference-specific transforms is applied to the images.
+        Returns:
+            A PyTorch tensor of the processed images.
+ """
+ image_size = image_size if image_size is not None else self.image_size
+ image_num_channels = (
+ image_num_channels
+ if image_num_channels is not None
+ else self.image_num_channels
+ )
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ size = (image_size, image_size)
+
+ if len(images) == 0:
+ return []
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ # For training a user needs to pass their own set of transforms as a Callable.
+ # For reference this is what was used in the original IDEFICS training:
+ # transform = transforms.Compose([
+ # convert_to_rgb,
+ # transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
+ # transforms.ToTensor(),
+ # transforms.Normalize(mean=image_mean, std=image_std),
+ # ])
+ if transform is not None:
+ if not is_torch_available():
+ raise ImportError("To pass in `transform` torch must be installed")
+ import torch
+
+ images = [transform(x) for x in images]
+ return torch.stack(images)
+
+ # for inference we do the exact transforms that were used to train IDEFICS
+ images = [convert_to_rgb(x) for x in images]
+ # further transforms expect numpy arrays
+ images = [to_numpy_array(x) for x in images]
+ images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
+ images = [self.rescale(image=image, scale=1 / 255) for image in images]
+ images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
+ images = [
+ to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
+ ]
+ # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
+ images = BatchFeature(
+ data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
+ )["pixel_values"]
+
+ return images
+
+ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+ """
+        Convert a single URL or a list of URLs into the corresponding `PIL.Image` objects.
+        If a single URL is passed, the return value is a single object; if a list is passed, a list of objects is
+        returned.
+ """
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+ " Safari/537.36"
+ )
+ }
+ if isinstance(image_url_or_urls, list):
+ return [self.fetch_images(x) for x in image_url_or_urls]
+ elif isinstance(image_url_or_urls, str):
+ image = image_url_or_urls
+
+ if image.startswith("http://") or image.startswith("https://"):
+ response = requests.get(
+ image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
+ )
+ response.raise_for_status()
+ content = response.content
+ elif image.startswith("data:"):
+ # https://stackoverflow.com/questions/17090571/is-there-a-way-to-set-background-image-as-a-base64-encoded-image
+ # data:image/png;base64,xxx
+ image = image.split(",")[-1]
+ content = base64.b64decode(image)
+ else:
+ raise ValueError(f"Unrecognized image {image}")
+
+ try:
+ image = Image.open(BytesIO(content))
+ # image.verify()
+ except Exception:
+ raise ValueError(f"Could not load image from url {image_url_or_urls}")
+ return image
+ else:
+ raise ValueError(
+                f"Only a single image URL (or data URI) or a list of them is supported, but got type={type(image_url_or_urls)}"
+ )
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`float`):
+ The scaling factor to rescale pixel values by.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+        # Forwarding `input_data_format` requires transformers >= 4.32:
+        # return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
+        return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `Iterable[float]`):
+ Image mean to use for normalization.
+ std (`float` or `Iterable[float]`):
+ Image standard deviation to use for normalization.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The normalized image.
+ """
+        # TODO: also forward `input_data_format` once transformers >= 4.32 is the minimum version
+        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+
+transformers.IdeficsImageProcessor = IdeficsImageProcessor
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
new file mode 100644
index 000000000..a130dbc12
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -0,0 +1,1474 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Idefics model."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers import PreTrainedModel
+from transformers.activations import ACT2FN
+from dataclasses import dataclass
+
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
+from text_generation_server.models.custom_modeling.idefics_vision import (
+ IdeficsVisionTransformer,
+)
+from text_generation_server.models.custom_modeling.idefics_perceiver import (
+ IdeficsPerceiverResampler,
+)
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ TensorParallelRowLinear,
+ SpeculativeHead,
+ FastLinear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from loguru import logger
+
+dropout_layer_norm = None
+
+
+@dataclass
+class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+# logger = logging.get_logger(__name__)
+
+# _CONFIG_FOR_DOC = "IdeficsConfig"
+
+# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [
+# "HuggingFaceM4/idefics-9b",
+# "HuggingFaceM4/idefics-80b",
+# # See all Idefics models at https://huggingface.co/models?filter=idefics
+# ]
+
+
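+# `expand_inputs_for_generation` mirrors what `generate()` does for beam search or
+# multiple return sequences: every batch row of `input_ids` is repeated `expand_size`
+# times (e.g. shape [2, seq] with expand_size=3 becomes [6, seq]); when an attention
+# mask is provided, the image attention mask and pixel values are expanded alongside
+# so they stay aligned with the duplicated text rows.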
+def expand_inputs_for_generation(
+ input_ids,
+ expand_size=1,
+ is_encoder_decoder=False,
+ attention_mask=None,
+ encoder_outputs=None,
+ **model_kwargs,
+):
+ expanded_return_idx = (
+ torch.arange(input_ids.shape[0])
+ .view(-1, 1)
+ .repeat(1, expand_size)
+ .view(-1)
+ .to(input_ids.device)
+ )
+ input_ids = input_ids.index_select(0, expanded_return_idx)
+
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(
+ 0, expanded_return_idx
+ )
+
+ if attention_mask is not None:
+ model_kwargs["attention_mask"] = attention_mask.index_select(
+ 0, expanded_return_idx
+ )
+ model_kwargs["image_attention_mask"] = model_kwargs[
+ "image_attention_mask"
+ ].index_select(0, expanded_return_idx)
+ model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(
+ 0, expanded_return_idx
+ )
+
+ if is_encoder_decoder:
+ if encoder_outputs is None:
+ raise ValueError(
+ "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
+ )
+ encoder_outputs["last_hidden_state"] = (
+ encoder_outputs.last_hidden_state.index_select(
+ 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+ )
+ )
+ model_kwargs["encoder_outputs"] = encoder_outputs
+ return input_ids, model_kwargs
+
+
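+# After each generation step, `update_model_kwargs_for_generation` stores the freshly
+# returned cache under `model_kwargs["past"]`, extends the text attention mask by one
+# column of ones, and carries over only the last step's image attention mask for the
+# next token.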
+def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
+ # must have this key set to at least None
+ model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)
+
+ # update past
+ if "past_key_values" in outputs:
+ model_kwargs["past"] = outputs.past_key_values
+ elif "mems" in outputs:
+ model_kwargs["past"] = outputs.mems
+ elif "past_buckets_states" in outputs:
+ model_kwargs["past"] = outputs.past_buckets_states
+ else:
+ model_kwargs["past"] = None
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat(
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+ )
+
+ # update attention masks
+ if not is_encoder_decoder:
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
+ dim=-1,
+ )
+ if "image_attention_mask" in model_kwargs:
+ image_attention_mask = model_kwargs["image_attention_mask"]
+ last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+ model_kwargs["image_attention_mask"] = last_mask
+
+ return model_kwargs
+
+
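+# `prepare_inputs_for_generation` trims `input_ids` (and `token_type_ids`) to the last
+# position once a `past` cache exists, and derives `position_ids` from the cumulative
+# sum of the attention mask so that padding tokens do not shift the positions.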
+def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
+ token_type_ids = kwargs.get("token_type_ids", None)
+ # only last token for inputs_ids if past is defined in kwargs
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ pixel_values = kwargs.get("pixel_values", None)
+ image_attention_mask = kwargs.get("image_attention_mask", None)
+ # if pixel_values is None or image_attention_mask is None:
+ # raise ValueError("pixel values and image attention mask cannot be None")
+
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "use_cache": kwargs.get("use_cache"),
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ "pixel_values": pixel_values,
+ "image_attention_mask": image_attention_mask,
+ }
+
+
+def freeze_model(model, module_exceptions=[]):
+ mapping = {
+ "LayerNorm": nn.LayerNorm,
+ "Linear": nn.Linear,
+ "Embedding": nn.Embedding,
+ }
+ module_exceptions_mapped = [mapping[m] for m in module_exceptions]
+ for module in model.modules():
+ if module_exceptions and any(
+ [isinstance(module, t) for t in module_exceptions_mapped]
+ ):
+ module.requires_grad_(
+ True
+            )  # Explicitly setting it to True to avoid any mistakes
+ else:
+ module.requires_grad_(False)
+ return model
+
+
+class IdeficsDecoupledPartialTPEmbedding(nn.Module):
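+    """
+    Embedding table split in two parts: the regular vocabulary is looked up through a
+    tensor-parallel (sharded) embedding, while the extra tokens added for IDEFICS
+    (indices >= `config.vocab_size`) are looked up in a small, non-sharded additional
+    embedding table.
+    """
+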
+ def __init__(
+ self,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_embeddings = config.vocab_size
+ self.weight = TensorParallelEmbedding(
+ prefix="model.embed_tokens", weights=weights
+ )
+ self.additional_weight = nn.Parameter(
+ weights.get_tensor("model.embed_tokens.additional_embedding.weight")
+ )
+
+ def forward(self, input_ids):
+ # Clone so that we don't modify the original input_ids later on
+ input_ids = input_ids.clone()
+ additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
+ additional_embeddings = torch.nn.functional.embedding(
+ input_ids_additional_vocab - self.num_embeddings, self.additional_weight
+ )
+
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
+ input_ids[additional_vocab_indices] = 0
+ full_vector = self.weight(input_ids)
+
+ # overwrite the records with high indices
+ full_vector[additional_vocab_indices] = additional_embeddings
+
+ return full_vector
+
+
+class IdeficsDecoupledTensorParallelLinear(nn.Module):
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
+ """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
+ then it will create `out_additional_features * in_features` additional parameters that are always trained. If
+ `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
+ """
+
+ def __init__(
+ self,
+ config,
+ weights,
+ ) -> None:
+ super().__init__()
+ self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
+ self.additional_fc = FastLinear.load(
+ config=config,
+ prefix="lm_head.additional_fc",
+ weights=weights,
+ bias=False,
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
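+        # Logits for the regular vocabulary come from the (possibly speculative) lm_head;
+        # logits for the extra IDEFICS tokens come from `additional_fc`. Both are
+        # concatenated along the vocabulary dimension.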
+ output, speculative_logits = self.fc(input)
+ additional_features = self.additional_fc(input)
+ output = torch.cat((output, additional_features), -1)
+
+ return output, speculative_logits
+
+    def extra_repr(self) -> str:
+        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
+        # The tensor-parallel rewrite does not keep the original nn.Linear attributes,
+        # so fall back to "n/a" instead of raising when the module is printed.
+        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
+            getattr(self, "in_features", "n/a"),
+            getattr(self, "out_features", "n/a"),
+            getattr(self, "out_additional_features", "n/a"),
+            getattr(self, "bias", None) is not None,
+            getattr(self, "partially_freeze", "n/a"),
+        )
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat(
+ [
+ torch.zeros(
+ tgt_len, past_key_values_length, dtype=dtype, device=device
+ ),
+ mask,
+ ],
+ dim=-1,
+ )
+ return mask[None, None, :, :].expand(
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
+ )
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
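+
+
+# For example, with no past and tgt_len == 3, `_make_causal_mask` returns a
+# [bsz, 1, 3, 3] tensor that is 0 on and below the diagonal and `torch.finfo(dtype).min`
+# above it, while `_expand_mask` turns a [bsz, seq_len] padding mask of 1s/0s into the
+# same additive convention (0 where attention is allowed, a large negative value where
+# it is masked).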
+
+
+class IdeficsRMSNorm(nn.Module):
+ def __init__(self, prefix, weights, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+
+ weight = weights.get_tensor(f"{prefix}.weight")
+ self.weight = nn.Parameter(weight)
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states, residual=None):
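+        # HPU-fused RMSNorm: when a residual is given, `hidden_states` is first added
+        # into it; the (summed) tensor is then normalized with the fused kernel and both
+        # the normalized output and the updated residual are returned, reshaped back to
+        # the original input shape.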
+ from vllm_hpu_extension.kernels import rms_norm
+
+ orig_shape = hidden_states.shape
+ if residual is not None:
+ residual += hidden_states.view(residual.shape)
+ else:
+ residual = hidden_states
+ # Note: HPUFusedRMSNorm requires 3D tensors as inputs
+ if len(orig_shape) == 2:
+ residual = residual.unsqueeze(0)
+ x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
+ return x.view(orig_shape), residual.view(orig_shape)
+
+
+# this was adapted from LlamaMLP
+class IdeficsMLP(nn.Module):
+ def __init__(
+ self,
+ config,
+ prefix,
+ weights,
+ ):
+ super().__init__()
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
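+        # `gate_up_proj` fuses the Llama-style gate_proj and up_proj into a single matmul;
+        # its output is split in two halves and combined as act_fn(gate) * up before the
+        # row-parallel down projection.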
+ gate_up_states = self.gate_up_proj(hidden_states)
+ shape = gate_up_states.shape
+ gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
+ return self.down_proj(
+ self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
+ )
+
+
+# this was adapted from LlamaAttention
+class IdeficsAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config,
+ prefix,
+ weights,
+ qk_layer_norms: bool = False,
+ is_cross_attention: bool = False,
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.dropout = config.dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.is_cross_attention = is_cross_attention
+
+ # if not hasattr(nn.functional, "scaled_dot_product_attention"):
+ # raise ValueError("this model requires pytorch 2.0 or higher")
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads //= weights.process_group.size()
+
+ if self.is_cross_attention:
+ # kv_input_dim = (
+ # self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim
+ # )
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+ )
+ else:
+ self.qkv = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+ self.o_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
+ )
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config, dim=self.head_dim, base=10000.0, device=weights.device
+ )
+ self.qk_layer_norms = qk_layer_norms
+ if self.qk_layer_norms:
+ self.q_layer_norm = IdeficsRMSNorm(
+ prefix=f"{prefix}.q_layer_norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.k_layer_norm = IdeficsRMSNorm(
+                prefix=f"{prefix}.k_layer_norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ if is_cross_attention:
+ query_states = self.q_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ) # .transpose(1, 2)
+ query_states = query_states.transpose(1, 2)
+            _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
+ key_states = (
+ self.k_proj(key_value_states)
+ .view(bsz, kv_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(key_value_states)
+ .view(bsz, kv_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ else:
+ qkv = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv.split(
+ self.num_heads * self.head_dim, dim=2
+ )
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ) # .transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ) # . transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ) # .transpose(1, 2)
+ kv_seq_len = q_len
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ max_s = max(kv_seq_len, q_len)
+ cos, sin = self.rotary_emb.get_cos_sin(
+ position_ids.view(-1), max_s, hidden_states.dtype
+ )
+
+ query_shape = query_states.shape
+ key_shape = key_states.shape
+ self.rotary_emb(
+ query_states.view(-1, *query_shape[2:]),
+ key_states.reshape(-1, *key_shape[2:]),
+ cos,
+ sin,
+ )
+
+ query_states = query_states.view(query_shape)
+ key_states = key_states.view(key_shape)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ if self.qk_layer_norms:
+ query_states = self.q_layer_norm(query_states)
+ key_states = self.k_layer_norm(key_states)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ attn_output = nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ attn_weights = None
+ if output_attentions:
+            logger.warning(
+                "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
+            )
+
+ return attn_output, attn_weights, past_key_value
+
+
+# this was adapted from LlamaDecoderLayer
+class IdeficsDecoderLayer(nn.Module):
+ def __init__(self, layer_id: int, config: IdeficsConfig, weights):
+ super().__init__()
+ self.process_group = weights.process_group
+ self.hidden_size = config.hidden_size
+ prefix = f"model.layers.{layer_id}"
+ self.self_attn = IdeficsAttention(
+ config=config,
+ prefix=f"{prefix}.self_attn",
+ weights=weights,
+ qk_layer_norms=False,
+ is_cross_attention=False,
+ )
+ self.mlp = IdeficsMLP(
+ config=config,
+ prefix=f"{prefix}.mlp",
+ weights=weights,
+ )
+ self.input_layernorm = IdeficsRMSNorm(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = IdeficsRMSNorm(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.dropout = config.dropout
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class IdeficsGatedCrossAttentionLayer(nn.Module):
+ def __init__(self, layer_id, config: IdeficsConfig, weights):
+ super().__init__()
+ self.process_group = weights.process_group
+ self.hidden_size = config.hidden_size
+ prefix = f"model.gated_cross_attn_layers.{layer_id}"
+ self.cross_attn = IdeficsAttention(
+ config=config,
+ prefix=f"{prefix}.cross_attn",
+ weights=weights,
+ qk_layer_norms=True,
+ is_cross_attention=True,
+ )
+ self.mlp = IdeficsMLP(
+ config=config,
+ prefix=f"{prefix}.mlp",
+ weights=weights,
+ )
+ self.input_layernorm = IdeficsRMSNorm(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = IdeficsRMSNorm(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.config = config.dropout
+
+ self.act_cross_attn = nn.Tanh()
+ self.act_dense = nn.Tanh()
+
+ self.alpha_cross_attn = nn.Parameter(
+ weights.get_tensor(f"{prefix}.alpha_cross_attn")
+ )
+ self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))
+
+ if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
+ raise ValueError("Alpha parameters not initialized correctly!")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_hidden_states: Optional[torch.Tensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ no_images: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored
+ """
+ if image_hidden_states is None:
+ raise ValueError(
+                "`image_hidden_states` (the visual features to condition on) is required"
+                " for the Idefics cross attention module."
+ )
+
+ if past_key_value is not None:
+ raise NotImplementedError(
+ "Past key value states are not implemented for Idefics cross attention module."
+ )
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+        # Cross Attention
+ hidden_states, self_attn_weights, present_key_value = self.cross_attn(
+ hidden_states=hidden_states,
+ key_value_states=image_hidden_states,
+ attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ )
+ # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
+ # when there are no images the model is used in pure language mode
+ gate = 0 if no_images else 1
+ hidden_states = (
+ residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
+ )
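+        # Flamingo-style gating: the cross-attention branch is scaled by
+        # tanh(alpha_cross_attn), and `gate` zeroes it out entirely when the batch
+        # carries no images, so the layer falls back to text-only behaviour.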
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
+ hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+ Parameters:
+ config ([`IdeficsConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+# @add_start_docstrings(
+# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+# LLAMA_START_DOCSTRING,
+# )
+class IdeficsPreTrainedModel(PreTrainedModel):
+ config_class = IdeficsConfig
+ # base_model_prefix = "model"
+ # supports_gradient_checkpointing = True
+ # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
+
+ # def _init_weights(self, module):
+ # # important: this ported version of Idefics isn't meant for training from scratch - only
+ # # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
+ # # base should be used for training from scratch and it contains the correct code.
+ # std = self.config.initializer_range
+ # if isinstance(module, nn.Linear):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.bias is not None:
+ # module.bias.data.zero_()
+ # elif isinstance(module, nn.Embedding):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.padding_idx is not None:
+ # module.weight.data[module.padding_idx].zero_()
+
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if isinstance(module, IdeficsModel):
+ # module.gradient_checkpointing = value
+
+
+# LLAMA_INPUTS_DOCSTRING = r"""
+# Args:
+# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+# it.
+
+# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+# [`PreTrainedTokenizer.__call__`] for details.
+
+# [What are input IDs?](../glossary#input-ids)
+# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+# - 1 for tokens that are **not masked**,
+# - 0 for tokens that are **masked**.
+
+# [What are attention masks?](../glossary#attention-mask)
+
+# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+# [`PreTrainedTokenizer.__call__`] for details.
+
+# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+# `past_key_values`).
+
+# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+# information on the default strategy.
+
+# - 1 indicates the head is **not masked**,
+# - 0 indicates the head is **masked**.
+# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+# `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+# model's internal embedding lookup matrix.
+# use_cache (`bool`, *optional*):
+# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+# `past_key_values`).
+# output_attentions (`bool`, *optional*):
+# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+# tensors for more detail.
+# output_hidden_states (`bool`, *optional*):
+# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+# more detail.
+# return_dict (`bool`, *optional*):
+# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+# """
+
+
+# @add_start_docstrings(
+# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+# LLAMA_START_DOCSTRING,
+# )
+class IdeficsModel(IdeficsPreTrainedModel):
+ # """
+ # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`]
+
+ # Args:
+ # config: IdeficsConfig
+ # """
+
+ def __init__(self, config: IdeficsConfig, weights):
+ super().__init__(config)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = IdeficsDecoupledPartialTPEmbedding(
+ config=config,
+ weights=weights,
+ )
+
+ self.image_size = config.vision_config.image_size
+ self.vision_config = config.vision_config
+ self.vision_model = IdeficsVisionTransformer(
+ prefix="model.vision_model",
+ config=config.vision_config,
+ weights=weights,
+ )
+
+ # Perceiver Resampler
+ if config.use_resampler:
+ perceiver_config = config.perceiver_config
+ self.perceiver_resampler = IdeficsPerceiverResampler(
+ prefix="model.perceiver_resampler",
+ config=config,
+ embed_dim=config.vision_config.embed_dim,
+ depth=perceiver_config.resampler_depth,
+ n_heads=perceiver_config.resampler_n_heads,
+ head_dim=perceiver_config.resampler_head_dim,
+ n_latents=perceiver_config.resampler_n_latents,
+ weights=weights,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ IdeficsDecoderLayer(layer_id, config, weights)
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+
+ self.cross_layer_interval = config.cross_layer_interval
+ num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
+ self.gated_cross_attn_layers = nn.ModuleList(
+ [
+ IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
+ for layer_id in range(num_cross_layers)
+ ]
+ )
+ # self.gradient_checkpointing = False
+
+ self.norm = IdeficsRMSNorm(
+ prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ # self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ # self.post_init()
+
+ # self.freeze_relevant_params(config)
+
+ # def freeze_relevant_params(self, config=None):
+ # if config is None:
+ # config = self.config
+
+ # if config.freeze_text_layers:
+ # self.freeze_text_layers(config.freeze_text_module_exceptions)
+
+ # if config.freeze_vision_layers:
+ # freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
+
+ # def freeze_text_layers(self, module_exceptions=[]):
+ # for module in [self.layers, self.norm]:
+ # freeze_model(module, module_exceptions=module_exceptions)
+
+ # def freeze_vision_layers(self, module_exceptions=[]):
+ # freeze_model(self.vision_model, module_exceptions=module_exceptions)
+
+ # def get_input_embeddings(self):
+ # return self.embed_tokens
+
+ # def set_input_embeddings(self, value):
+ # self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastImage]:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError(
+                "You have to specify either input_ids or inputs_embeds"
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ elif position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ no_images = False
+
+ if image_hidden_states is None:
+ if pixel_values is None and image_embeddings is None:
+ raise ValueError(
+                    "Either pixel_values or image_embeddings has to be provided (not None)."
+ )
+
+ elif pixel_values is not None and image_embeddings is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and image_embeddings at the same time"
+ )
+
+ elif pixel_values is not None:
+ no_images = len(torch.nonzero(pixel_values)) == 0
+ pixel_values = pixel_values.to(
+ dtype=self.dtype, device=device
+ ) # fp16 compatibility
+ batch_size, num_images = pixel_values.shape[:2]
+ pixel_values = pixel_values.contiguous().view(
+ batch_size * num_images, *pixel_values.shape[2:]
+ )
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values
+ ).last_hidden_state
+
+ elif image_embeddings is not None:
+ (
+ batch_size,
+ num_images,
+ image_seq_len,
+ image_hidden_size,
+ ) = image_embeddings.size()
+ image_hidden_states = image_embeddings.to(
+ dtype=self.dtype, device=input_ids.device
+ )
+ image_hidden_states = image_hidden_states.view(
+ batch_size * num_images, image_seq_len, image_hidden_size
+ )
+
+ if self.config.use_resampler:
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
+ image_seq_len, image_hidden_size = image_hidden_states.size(
+ 1
+ ), image_hidden_states.size(2)
+ image_hidden_states = image_hidden_states.view(
+ batch_size, num_images * image_seq_len, image_hidden_size
+ )
+ else:
+ no_images = False
+ num_images = pixel_values.shape[1]
+ image_seq_len = image_hidden_states.shape[1] // num_images
+
+ # # Hack to use the model in full language modeling mode
+ # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
+ # Make image_attention_mask compatible with hidden states
+ text_seq_len = image_attention_mask.size(1)
+ image_attention_mask = image_attention_mask.unsqueeze(-1)
+ image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
+ image_attention_mask = image_attention_mask.view(
+ batch_size, text_seq_len, num_images * image_seq_len
+ )
+ image_batch_size, image_sequence_length, _ = image_hidden_states.size()
+ image_hidden_shape = (image_batch_size, image_sequence_length)
+ if image_attention_mask is None:
+ image_attention_mask = torch.ones(image_hidden_shape, device=device)
+ image_attention_mask = self.invert_attention_mask(image_attention_mask)
+
+ # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
+ # raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
+
+ # if image_hidden_states is not None:
+ # else:
+ # image_attention_mask = None
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ hidden_states = inputs_embeds
+
+ # if self.gradient_checkpointing and self.training:
+ # if use_cache:
+ # logger.warning_once(
+ # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ # )
+ # use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
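+        # Gated cross-attention is interleaved with the decoder: inside `vblock`, every
+        # `cross_layer_interval`-th decoder layer is preceded by
+        # `gated_cross_attn_layers[idx // cross_layer_interval]`, which attends from the
+        # text hidden states to the image hidden states before the regular layer runs.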
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ def vblock(
+ main_block,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ image_hidden_states,
+ image_attention_mask,
+ output_attentions,
+ use_cache,
+ no_images,
+ layer_idx,
+ cross_layer_interval,
+ gated_cross_attn_layers,
+ ):
+ # TODO(ls): Add cross attention values to respective lists
+ if layer_idx % cross_layer_interval == 0:
+ xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
+ outputs = xblock(
+ hidden_states,
+ attention_mask=attention_mask,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ past_key_value=None, # not implemented
+ no_images=no_images,
+ )
+ hidden_states = outputs[0]
+
+ layer_outputs = main_block(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ return layer_outputs
+
+ # if self.gradient_checkpointing and self.training:
+ # past_key_value = None
+ # if use_cache:
+ # logger.warning_once(
+ # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ # )
+ # use_cache = False
+
+ # layer_outputs = torch.utils.checkpoint.checkpoint(
+ # vblock,
+ # decoder_layer,
+ # hidden_states,
+ # attention_mask,
+ # position_ids,
+ # past_key_value,
+ # image_hidden_states,
+ # image_attention_mask,
+ # output_attentions,
+ # use_cache,
+ # no_images,
+ # idx,
+ # self.cross_layer_interval,
+ # self.gated_cross_attn_layers,
+ # )
+ # else:
+ layer_outputs = vblock(
+ decoder_layer,
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ no_images=no_images,
+ layer_idx=idx,
+ cross_layer_interval=self.cross_layer_interval,
+ gated_cross_attn_layers=self.gated_cross_attn_layers,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPastImage(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ image_hidden_states=image_hidden_states,
+ )
+
+
+class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
+ def __init__(
+ self,
+ config,
+ weights,
+ ):
+ super().__init__(config)
+ self.model = IdeficsModel(
+ config=config,
+ weights=weights,
+ )
+
+ self.lm_head = IdeficsDecoupledTensorParallelLinear(
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPastImage]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ image_embeddings=image_embeddings,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits, speculative_logits = self.lm_head(hidden_states)
+
+ loss = None
+
+ return (
+ CausalLMOutputWithPastImage(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ ),
+ speculative_logits,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
+ unwanted_kwargs = ["token_type_ids"]
+ for kwarg in unwanted_kwargs:
+ inputs.pop(kwarg, None)
+ return inputs
+
+ @staticmethod
+ def _expand_inputs_for_generation(
+ *args,
+ **model_kwargs,
+ ):
+ return expand_inputs_for_generation(*args, **model_kwargs)
+
+ @staticmethod
+ def _update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=False
+ ):
+ return update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
+ )
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
new file mode 100644
index 000000000..6da8045bc
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
@@ -0,0 +1,276 @@
+# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
+#
+# MIT License
+#
+# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+"""
+
+Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
+time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
+that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
+prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
+to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
+
+References:
+ - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
+ - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
+
+"""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+EPS = 1e-5
+
+
+class IdeficsPerceiverResampler(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config,
+ embed_dim: int,
+ depth: int,
+ n_heads: int,
+ head_dim: int,
+ n_latents: int,
+ weights,
+ ) -> None:
+ """
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
+ returns a Tensor of shape [bsz, n_latents, embed_dim]. `embed_dim` is both the dimensionality of the embeddings
+ fed to the Perceiver Resampler and the dimensionality of the latent embeddings it *returns* (it could be, e.g.,
+ a ViT embed_dim, a ResNet pool dim, and so on).
+
+ Args:
+ config (`IdeficsConfig`): config object
+ embed_dim (`int`): The size of each embedding vector
+ depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
+ n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
+ head_dim (`int`): Dimensionality of each head projection in the Transformer block.
+ n_latents (`int`):
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+
+ """
+ super().__init__()
+ self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
+ embed_dim,
+ n_heads,
+ head_dim,
+ n_latents,
+ )
+ self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
+
+ # Create Latents for Perceiver
+ self.latents = nn.Parameter(weights.get_tensor(f"{prefix}.latents"))
+
+ self.intermediate_dim = (
+ self.embed_dim * 4
+ if not hasattr(config.vision_config, "embed_dim")
+ else config.vision_config.embed_dim * 4
+ )
+ # Create Transformer Blocks
+ self.blocks = nn.ModuleList(
+ [
+ nn.ModuleList(
+ [
+ IdeficsPerceiverAttention(
+ prefix=f"{prefix}.blocks.{layer_id}.0",
+ config=config,
+ embed_dim=self.embed_dim,
+ n_heads=self.n_heads,
+ head_dim=self.head_dim,
+ qk_layer_norms=self.qk_layer_norms,
+ weights=weights,
+ ),
+ IdeficsMLP(
+ prefix=f"{prefix}.blocks.{layer_id}.1",
+ intermediate_size=self.intermediate_dim,
+ config=config,
+ weights=weights,
+ ),
+ ]
+ )
+ for layer_id in range(depth)
+ ]
+ )
+ self.layer_norm = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
+ )
+
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
+ """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
+ # einsum.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
+ latents = self.latents.repeat(context.shape[0], 1, 1)
+
+ # Feed through Perceiver Attention blocks...
+ for attn, ff in self.blocks:
+ latents = attn(context, latents) + latents
+ latents = ff(latents) + latents
+
+ return self.layer_norm(latents)
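+
+# A rough shape sketch for the resampler above (sizes are hypothetical, not taken from the config):
+# a context of shape [bsz, seq, embed_dim], e.g. [1, 257, 1280] from a ViT, is compressed to
+# [bsz, n_latents, embed_dim], e.g. [1, 64, 1280], independently of the input sequence length.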
+
+
+class IdeficsPerceiverAttention(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ config,
+ embed_dim: int,
+ n_heads: int,
+ head_dim: int,
+ qk_layer_norms: bool,
+ weights,
+ ) -> None:
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
+ super().__init__()
+ self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
+ self.qk_layer_norms = qk_layer_norms
+ # Normalization & Scaling
+ self.context_layer_norm = nn.LayerNorm.load(
+ prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
+ )
+ self.latents_layer_norm = nn.LayerNorm.load(
+ prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
+ )
+ if self.qk_layer_norms:
+ self.q_layer_norm = nn.LayerNorm.load(
+ prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
+ )
+ self.k_layer_norm = nn.LayerNorm.load(
+ prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
+ )
+
+ self.qk_scale = self.head_dim**-0.5
+
+ if n_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.n_heads //= weights.process_group.size()
+
+ # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
+ self.q_proj = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
+ )
+ self.k_proj = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+ )
+
+ self.output_proj = TensorParallelRowLinear.load(
+ config=config, prefix=f"{prefix}.output_proj", weights=weights, bias=False
+ )
+
+ def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
+ """
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
+
+ Args:
+ context (`torch.Tensor`):
+ Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
+ latents (`torch.Tensor`):
+ Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
+
+ Returns:
+ `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
+ from context.
+ """
+ context = self.context_layer_norm(context)
+ latents = self.latents_layer_norm(latents)
+ batch_size, seq_length, embed_dim = context.shape[:3]
+
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
+ q = self.q_proj(latents)
+ k = self.k_proj(torch.cat([context, latents], dim=-2))
+ v = self.v_proj(torch.cat([context, latents], dim=-2))
+
+ # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
+ # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
+ # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
+ q, k, v = [
+ x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
+ 1, 2
+ )
+ for x in (q, k, v)
+ ]
+
+ if self.qk_layer_norms:
+ q = self.q_layer_norm(q)
+ k = self.k_layer_norm(k)
+
+ scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
+ stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
+ attn = stabilized_scores.softmax(dim=-1)
+
+ # Attend & project back to output...
+ resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
+ # einsum.rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
+ return self.output_proj(resampled.transpose(1, 2).flatten(-2))
+
+
+class IdeficsMLP(nn.Module):
+ def __init__(
+ self,
+ prefix,
+ intermediate_size,
+ config,
+ weights,
+ ):
+ """Simple MLP block with intermediate_size and embedding size"""
+ super().__init__()
+ self.embed_dim = config.vision_config.embed_dim
+ self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
+ self.fc = TensorParallelColumnLinear.load(
+ config=config,
+ prefix=f"{prefix}.fc",
+ weights=weights,
+ bias=False,
+ )
+ self.act = nn.ReLU()
+ self.c_proj = TensorParallelRowLinear.load(
+ config=config,
+ prefix=f"{prefix}.c_proj",
+ weights=weights,
+ bias=False,
+ )
+
+ def forward(
+ self, hidden_states: Optional[Tuple[torch.FloatTensor]]
+ ) -> torch.FloatTensor:
+ hidden_states = self.ln(hidden_states)
+ hidden_states = self.fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+
+ return hidden_states
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py
new file mode 100644
index 000000000..ca61e27d4
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_processing.py
@@ -0,0 +1,443 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for IDEFICS.
+"""
+
+from typing import Callable, List, Optional, Union
+from urllib.parse import urlparse
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import (
+ BatchEncoding,
+ PaddingStrategy,
+ TextInput,
+ TruncationStrategy,
+)
+from transformers.utils import TensorType, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+
+IMAGE_TOKEN = ""
+
+
+# copied from m4.training.packing
+def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
+ # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
+
+ # If any image index is at or above num_classes, set it to -1.
+ # Tokens appearing after the maximum allowed number of images has been seen do not attend to any image.
+ if num_classes != -1:
+ incremental_mask[incremental_mask >= num_classes] = -1
+
+ negatives = incremental_mask == -1
+ incremental_mask[negatives] = 0
+ attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
+ attn_mask[negatives, :] = 0
+ return attn_mask
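+
+# A minimal illustration of the helper above (hypothetical tensor values): with two images allowed
+# (num_classes=2), an incremental mask of [-1, 0, 0, 1] -- no image yet, first image, first image,
+# second image -- becomes a per-image one-hot attention mask:
+# >>> mask = torch.tensor([[-1, 0, 0, 1]])
+# >>> incremental_to_binary_attention_mask(mask, num_classes=2)
+# # yields, up to print formatting: [[[0, 0], [1, 0], [1, 0], [0, 1]]]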
+
+
+# copied from m4.training.packing
+def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
+ image_attention_mask = torch.full_like(input_ids, fill_value=-1)
+ next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
+ image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+ eod_token_id = tokenizer.eos_token_id
+ for batch_idx in range(input_ids.size(0)):
+ count = -1
+ seen_eod = False
+ for idx, token_id in enumerate(input_ids[batch_idx]):
+ if token_id == image_token_id:
+ count += 1
+ image_attention_mask[batch_idx][idx] = count
+ seen_eod = False
+ else:
+ image_attention_mask[batch_idx][idx] = count
+
+ if seen_eod:
+ image_attention_mask[batch_idx][idx] = -1
+
+ if token_id == eod_token_id:
+ seen_eod = True
+
+ for batch_idx in range(input_ids.size(0)):
+ count = -1
+ seen_eod = False
+ for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
+ token_id = input_ids[batch_idx][idx]
+ if token_id == image_token_id:
+ count += 1
+ next_image_attention_mask[batch_idx][idx] = count
+ seen_eod = False
+ else:
+ next_image_attention_mask[batch_idx][idx] = count
+
+ if token_id == eod_token_id:
+ seen_eod = True
+
+ if seen_eod:
+ next_image_attention_mask[batch_idx][idx] = -1
+
+ non_negative_indices = next_image_attention_mask[batch_idx] != -1
+ next_image_attention_mask[batch_idx][non_negative_indices] -= count
+ next_image_attention_mask[batch_idx][non_negative_indices] *= -1
+
+ return image_attention_mask, next_image_attention_mask
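+
+# A hedged worked example of what the first returned mask encodes (token ids are hypothetical):
+# for a packed sequence like [<image>, t, t, <image>, t] the forward mask is [0, 0, 0, 1, 1],
+# i.e. every token carries the index of the most recent preceding image, while -1 marks positions
+# that have not yet seen an image or that follow an EOS token.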
+
+
+def is_url(string):
+ """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
+ invalidated the url"""
+ if " " in string:
+ return False
+ result = urlparse(string)
+ return all([result.scheme, result.netloc])
+
+
+def is_image(string):
+ """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
+ invalidated the url"""
+ return is_url(string) or string.startswith("data:")
+
+
+class IdeficsProcessor(ProcessorMixin):
+ r"""
+ Constructs an IDEFICS processor which wraps a Llama tokenizer and an IDEFICS image processor into a single processor.
+
+ [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See
+ the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
+
+ Args:
+ image_processor (`IdeficsImageProcessor`):
+ An instance of [`IdeficsImageProcessor`]. The image processor is a required input.
+ tokenizer (`LlamaTokenizerFast`):
+ An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
+ image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "IdeficsImageProcessor"
+ tokenizer_class = "LlamaTokenizerFast"
+
+ def __init__(
+ self,
+ image_processor,
+ tokenizer=None,
+ image_size=224,
+ add_end_of_utterance_token=None,
+ **kwargs,
+ ):
+ if image_processor is None:
+ raise ValueError("You need to specify an `image_processor`.")
+ if tokenizer is None:
+ raise ValueError("You need to specify a `tokenizer`.")
+
+ super().__init__(image_processor, tokenizer)
+ self.current_processor = self.image_processor
+ self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+
+ self.default_image_dims = (
+ self.image_processor.image_num_channels,
+ self.image_processor.image_size,
+ self.image_processor.image_size,
+ )
+
+ self.tokenizer_was_trained_with_end_of_utterance_token = (
+ True
+ if ""
+ in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
+ else False
+ )
+
+ def __call__(
+ self,
+ prompts: Union[List[TextInput], List[List[TextInput]]],
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ transform: Callable = None,
+ add_eos_token=False,
+ add_end_of_utterance_token=None,
+ debug=False,
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+ ) -> BatchEncoding:
+ """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
+ the model was trained on and prepares the image pixel values for the model to process.
+
+ Args:
+ prompts (`Union[List[TextInput], List[List[TextInput]]]`):
+ either a single prompt or a batched list of prompts - see the detailed description immediately after
+ the end of the arguments doc section.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ transform (`Callable`, *optional*):
+ A custom transform function that accepts a single image can be passed for training. For example,
+ `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
+ set of transforms will be applied to the images
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Adds `eos_token` at the end of the final prompt if `True`
+ add_end_of_utterance_token (`bool`, *optional*):
+ Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
+ image). If `None` the tokenizer will be checked instead and if this token is found in
+ `additional_special_tokens` then the value will be `True`.
+ debug (`bool`, *optional*, defaults to `False`):
+ `True` value will help debug prompt generation by dumping useful information
+ return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
+ The type of tensors to return. Can be one of:
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+
+ Returns:
+ a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
+ directly passed to `model.generate`
+
+ Detailed explanation:
+
+ Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
+
+ An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
+
+ When the processor encounters an image it'll inject a
+ `<fake_token_around_image><image><fake_token_around_image>` entry into the prompt.
+
+ Example:
+
+ ```python
+ checkpoint = "HuggingFaceM4/idefics-9b"
+ processor = AutoProcessor.from_pretrained(checkpoint)
+ url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
+ img = processor.image_processor.fetch_images([url])[0]
+
+ prompts = [
+ "User:",
+ img,
+ "Describe this image.\nAssistant: An image of two kittens in grass.\n",
+ "User:",
+ "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg",
+ "Describe this image.\nAssistant:",
+ ]
+
+ inputs = processor(prompts, return_tensors="pt")
+ generated_ids = model.generate(**inputs, max_length=100)
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```
+
+ In this example the `prompts` will be converted into:
+
+ ```
+ <s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
+ Assistant: An image of two kittens in grass.<end_of_utterance>
+ User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
+ Assistant:'
+ ```
+
+ and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the
+ `pixel_values` dict entry of the return value.
+
+ This example also shows that images can be passed either as objects or as text urls. It can be seen that the
+ first image is passed as an object and the second one as a url.
+
+ For training, do:
+
+ ```python
+ image_transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=self.image_mean, std=self.image_std),
+ ]
+ )
+ inputs = processor(prompts, transform=image_transform, return_tensors="pt")
+ ```
+
+ In order to help debug prompt generation enable `debug=True` which will show you what's happening.
+
+ """
+
+ # if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
+ if add_end_of_utterance_token is None:
+ add_end_of_utterance_token = (
+ self.tokenizer_was_trained_with_end_of_utterance_token
+ )
+
+ # turn non-batched prompts into batched
+ if not any(isinstance(i, list) for i in prompts):
+ prompts = [prompts]
+
+ fake_token = ""
+ image_token = ""
+ end_of_utterance_token = ""
+
+ def image_tokens(last_was_image):
+ if last_was_image:
+ return image_token + fake_token
+ else:
+ return fake_token + image_token + fake_token
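+
+ # Illustration of the helper above (a sketch relying on the IDEFICS special tokens defined earlier):
+ # image_tokens(False) -> "<fake_token_around_image><image><fake_token_around_image>"
+ # image_tokens(True) -> "<image><fake_token_around_image>", reusing the fake token that closed
+ # the previous image.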
+
+ all_texts = []
+ all_images = []
+ for sample in prompts:
+ # the model was trained on samples starting with <s>
+ full_text = f"{self.tokenizer.bos_token}"
+
+ # an image can either be an image object in the item or the url, everything else is a verbatim prompt text
+ image_objects = []
+ last_was_image = False
+ last_was_text = False
+ for i, item in enumerate(sample):
+ if i > 0:
+ last_was_text = True if not last_was_image else False
+
+ if isinstance(item, str):
+ item = item.strip(" ")
+ if is_image(item):
+ image = self.image_processor.fetch_images(item)
+ full_text += image_tokens(last_was_image)
+ image_objects.append(image)
+ last_was_image = True
+ else:
+ # we add end_of_utterance_token between subsequent text prompts (but not after the last one!)
+ if add_end_of_utterance_token and last_was_text:
+ full_text += end_of_utterance_token
+ full_text += item
+ last_was_image = False
+ else:
+ # must be an image obj
+ full_text += image_tokens(last_was_image)
+ image_objects.append(item)
+ last_was_image = True
+
+ if add_eos_token:
+ full_text += self.tokenizer.eos_token
+
+ if debug is True:
+ print(f"{full_text=}")
+
+ image_objects = self.image_processor(image_objects, transform=transform)
+
+ text_encoding = self.tokenizer(
+ text=full_text,
+ add_special_tokens=False,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ )
+
+ all_texts.append(text_encoding["input_ids"])
+ all_images.append(image_objects)
+
+ max_seq_len = max(len(x) for x in all_texts)
+
+ # max_num_images has to be at least 1 even when there are no images
+ max_num_images = max(len(x) for x in all_images)
+ max_num_images = max(1, max_num_images)
+
+ at_least_one_image = sum(len(x) for x in all_images) > 0
+ output_input_ids = []
+ output_images = []
+ output_attention_masks = []
+ for text, images in zip(all_texts, all_images):
+ padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len
+ unpadded_seq_len = len(text)
+ start = max_seq_len - unpadded_seq_len
+ padded_input_ids[start:] = text[:max_seq_len]
+
+ attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
+ attention_mask[start:] = 1
+
+ image_count = padded_input_ids.count(self.image_token_id)
+ local_max_num_images = min(image_count, max_num_images)
+
+ current_images = images[:local_max_num_images]
+
+ if len(current_images) > 0:
+ padded_image_tensor = torch.zeros(
+ max_num_images, *current_images.size()[1:]
+ )
+ padded_image_tensor[: current_images.size(0)] = current_images
+ else:
+ padded_image_tensor = torch.zeros(
+ max_num_images, *self.default_image_dims
+ )
+
+ output_images.append(padded_image_tensor)
+ output_input_ids.append(torch.tensor(padded_input_ids))
+
+ output_attention_masks.append(attention_mask)
+
+ output_input_ids = torch.stack(output_input_ids)
+ output_images = torch.stack(output_images)
+ output_attention_masks = torch.stack(output_attention_masks)
+
+ if at_least_one_image:
+ image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
+ output_input_ids, self.tokenizer
+ )
+ image_attention_mask = incremental_to_binary_attention_mask(
+ image_attention_mask, num_classes=max_num_images
+ )
+ else:
+ # in full language mode we set the image mask to all-0s
+ image_attention_mask = torch.zeros(
+ output_input_ids.shape[0],
+ output_input_ids.shape[1],
+ 1,
+ dtype=torch.bool,
+ )
+
+ return BatchFeature(
+ data={
+ "input_ids": output_input_ids,
+ "attention_mask": output_attention_masks,
+ "pixel_values": output_images,
+ "image_attention_mask": image_attention_mask,
+ }
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py
new file mode 100644
index 000000000..dd8f76bc4
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics_vision.py
@@ -0,0 +1,529 @@
+# coding=utf-8
+# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.utils import (
+ ModelOutput,
+ logging,
+)
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+)
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class IdeficsVisionModelOutput(ModelOutput):
+ """
+ Base class for the vision model's outputs that also contains the image embeddings obtained by pooling the last hidden states.
+
+ Args:
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics
+class IdeficsVisionEmbeddings(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ weights.get_tensor(f"{prefix}.class_embedding")
+ )
+
+ self.patch_embedding = nn.Conv2d.load_no_bias(
+ prefix=f"{prefix}.patch_embedding",
+ weights=weights,
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = TensorParallelEmbedding(
+ prefix="model.vision_model.embeddings.position_embedding", weights=weights
+ )
+ self.position_ids = (
+ torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
+ )
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(
+ pixel_values.to(dtype=target_dtype)
+ ) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
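+
+# Rough shape walk-through for the embeddings above, assuming image_size=224 and patch_size=14
+# (hypothetical values): the Conv2d yields (224 // 14) ** 2 = 256 patch embeddings, the learned
+# class embedding is prepended, and position embeddings are added, giving 257 vectors of size
+# hidden_size per image.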
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision
+class IdeficsVisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+ f"and `num_shards`: {weights.process_group.size()}"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
+ )
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + causal_attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and has to be reused in the following
+ attn_weights_reshaped = attn_weights.view(
+ bsz, self.num_heads, tgt_len, src_len
+ )
+ attn_weights = attn_weights_reshaped.view(
+ bsz * self.num_heads, tgt_len, src_len
+ )
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
+class IdeficsVisionMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.fc1", weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.fc2", weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
+class IdeficsVisionEncoderLayer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = IdeficsVisionAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+ )
+ self.mlp = IdeficsVisionMLP(
+ prefix=f"{prefix}.mlp", config=config, weights=weights
+ )
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision
+class IdeficsVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`IdeficsVisionEncoderLayer`].
+
+ Args:
+ config: IdeficsVisionConfig
+ """
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ IdeficsVisionEncoderLayer(
+ prefix=f"{prefix}.encoder.layers.{layer_id}",
+ config=config,
+ weights=weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ # self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # if self.gradient_checkpointing and self.training:
+
+ # def create_custom_forward(module):
+ # def custom_forward(*inputs):
+ # return module(*inputs, output_attentions)
+
+ # return custom_forward
+
+ # layer_outputs = torch.utils.checkpoint.checkpoint(
+ # create_custom_forward(encoder_layer),
+ # hidden_states,
+ # attention_mask,
+ # causal_attention_mask,
+ # )
+ # else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, encoder_states, all_attentions]
+ if v is not None
+ )
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
+class IdeficsVisionTransformer(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = IdeficsVisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.pre_layrnorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
+ )
+ self.encoder = IdeficsVisionEncoder(
+ prefix=prefix, config=config, weights=weights
+ )
+ self.post_layernorm = nn.LayerNorm.load(
+ prefix=f"{prefix}.post_layernorm",
+ weights=weights,
+ eps=config.layer_norm_eps,
+ )
+
+ # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+ hidden_states = self.pre_layrnorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
new file mode 100644
index 000000000..00ecdf952
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -0,0 +1,467 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Llava-NeXT model."""
+
+from typing import List, Optional, Union
+
+import torch
+import torch.utils.checkpoint
+import numpy as np
+
+from loguru import logger
+from transformers.models.llava_next.modeling_llava_next import (
+ unpad_image,
+)
+from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
+from transformers.image_processing_utils import select_best_resolution
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (width, height).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
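+
+# Worked example for the helper above (hypothetical numbers): if select_best_resolution picks
+# (672, 336) for the input image and patch_size is 336, the grid shape is
+# (672 // 336, 336 // 336) = (2, 1).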
+
+
+# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """
+ Calculate the number of patches after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ int: the number of patches
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+ # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple, otherwise the calculation will be wrong
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type {type(image_size)} with value {image_size}"
+ )
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+ num_patches = 0
+ # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ num_patches += 1
+ # add the base patch
+ num_patches += 1
+ return num_patches
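+
+# Worked example for the helper above (hypothetical numbers): if select_best_resolution returns
+# (672, 672) and patch_size is 336, the double loop counts 2 * 2 = 4 patches and the base patch
+# adds one more, so the function returns 5.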
+
+
+class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[int] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = True,
+ flash_attention_recompute: Optional[bool] = True,
+ ):
+
+ if token_idx is not None:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+
+ logits = outputs[0]
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return output
+
+ return outputs
+
+ # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L411
+ def pack_image_features(
+ self,
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy,
+ image_newline=None,
+ ):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Args:
+ image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
+ List of image feature tensors, each containing all the visual features of all patches.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+ Actual image size of each image (H, W).
+ vision_feature_select_strategy (`str`)
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ image_newline (`torch.Tensor` of shape `(embed_dim)`)
+ New line embedding vector.
+ Returns:
+ image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
+ feature_lens (`List[int]`)
+ token length of each image in image_features
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = (
+ self.config.vision_config.image_size
+ // self.config.vision_config.patch_size
+ )
+
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+
+ if (
+ np.prod(image_feature.shape)
+ % (num_patch_height * num_patch_width * height * width)
+ != 0
+ and vision_feature_select_strategy == "default"
+ ):
+ logger.warning_once(
+ "Image feature shape does not line up with the provided patch size. "
+ "You may be using the `default` vision_feature_select_strategy with a"
+ " visual encoder that does not have CLS."
+ )
+
+ image_feature = image_feature.view(
+ num_patch_height, num_patch_width, height, width, -1
+ )
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (image_feature, image_newline[None].to(image_feature)), dim=0
+ )
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ image_features = torch.cat(new_image_features, dim=0)
+ feature_lens = torch.tensor(
+ feature_lens, dtype=torch.long, device=image_features.device
+ )
+ return image_features, feature_lens
+
+ # Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: Union[int, List[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+ Obtains the image's last hidden states from the vision tower and applies the multimodal projection.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
+ The tensors corresponding to the input images.
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+ Actual image size of each image (H, W).
+ vision_feature_layer (`Union[int, List[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+ image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
+ and of shape `(num_patches, image_length, embed_dim)`.
+ """
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ )
+ for imsize in image_sizes
+ ]
+ if pixel_values.dim() == 5:
+ # stacked if input is (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [
+ pix_val[:num_patch]
+ for pix_val, num_patch in zip(pixel_values, image_num_patches)
+ ]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+ raise ValueError(
+ f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
+ )
+
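+ # pixel_values is now a flat stack of shape (sum(image_num_patches), channels, height, width),
+ # so the vision tower processes every patch of every image in a single forward pass.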
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [
+ image_features.hidden_states[layer_idx]
+ for layer_idx in vision_feature_layer
+ ]
+ selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+ return image_features
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ image_sizes=None,
+ attention_mask=None,
+ **kwargs,
+ ):
+ """
+ Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
+ The only differences are:
+ - add a new `token_idx` arg
+ - add the merging of image features into `inputs_embeds`
+ """
+ token_idx = kwargs.get("token_idx", None)
+ if token_idx is None:
+ return super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ else:
+ use_flash_attention = kwargs.get("use_flash_attention", True)
+ flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
+
+ position_ids = kwargs.get("position_ids", None)
+ labels = kwargs.get("labels", None)
+ if (
+ past_key_values is None
+ and pixel_values is not None
+ and input_ids.shape[1] != 1
+ ):
+ vision_feature_select_strategy = kwargs.get(
+ "vision_feature_select_strategy", None
+ )
+ vision_feature_layer = kwargs.get("vision_feature_layer", None)
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ vision_feature_layer = (
+ vision_feature_layer
+ if vision_feature_layer is not None
+ else self.config.vision_feature_layer
+ )
+
+ # 1. Extract the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ # 2. Merge text and images
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
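+ # Replace every image placeholder token embedding with the corresponding row of packed visual
+ # features; the counts must match, which the check below enforces.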
+ special_image_mask = (
+ input_ids == self.config.image_token_index
+ ).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+ inputs_embeds.device
+ )
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+
+ image_features = image_features.to(
+ inputs_embeds.device, inputs_embeds.dtype
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(
+ special_image_mask, image_features
+ )
+
+ # When input_ids.shape[1] == 1, pixel_values is None and past_key_values is not None,
+ # we are generating with the cache
+ elif past_key_values is not None:
+ seq_len = input_ids.shape[1]
+ pad_len = seq_len - token_idx.item()
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
+ # that are set to 0
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+ batch_index, non_attended_tokens = torch.where(
+ first_layer_past_key_value.float().sum(-2) == 0
+ )
+ # Get the target length
+ past_length = first_layer_past_key_value.shape[-1]
+ extended_attention_mask = torch.ones(
+ (attention_mask.shape[0], past_length),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+ # Filter out only the tokens that can be un-attended, this can happen
+ # if one uses Llava + Fused modules where the cache on the
+ # first iteration is already big enough, or if one passes custom cache
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+ # Zero-out the places where we don't need to attend
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+ attention_mask = extended_attention_mask
+ attention_mask[:, -pad_len:] = 0
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = (
+ torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+ )
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "token_idx": token_idx,
+ "labels": labels,
+ "use_flash_attention": use_flash_attention,
+ "flash_attention_recompute": flash_attention_recompute,
+ }
+ )
+
+ return model_inputs
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py
new file mode 100644
index 000000000..5a9c05887
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -0,0 +1,238 @@
+import torch
+import torch.distributed
+
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from torch import nn
+from typing import Optional, Tuple, Any
+from transformers.configuration_utils import PretrainedConfig
+import torch.nn.functional as F
+
+from text_generation_server.layers import (
+ SpeculativeHead,
+ TensorParallelEmbedding,
+ FastLinear,
+)
+from text_generation_server.layers.layernorm import FastRMSNorm
+
+from einops import rearrange
+from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+import math
+from dataclasses import dataclass
+
+
+@dataclass
+class InferenceParams:
+ """Inference parameters that are passed to the main model in order
+ to efficienly calculate and store the context during inference."""
+
+ max_seqlen: int
+ max_batch_size: int
+ conv_states: torch.Tensor
+ ssm_states: torch.Tensor
+ seqlen_offset: int
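+ # Expected buffer shapes, inferred from how MambaBlock/MambaModel use them below (not enforced here):
+ #   conv_states: (n_layer, batch_size, d_inner, d_conv)
+ #   ssm_states:  (n_layer, batch_size, d_inner, d_state)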
+
+
+class MambaConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=50280,
+ d_model=768,
+ d_state=16,
+ n_layer=32,
+ layer_norm_epsilon=1e-5,
+ tie_word_embeddings=False,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ expand=2,
+ dt_rank="auto",
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.n_layer = n_layer
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.d_model = d_model
+ self.d_inner = d_model * 2
+ self.d_conv = 4
+ self.d_state = d_state
+ self.expand = expand
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MambaBlock(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ self.layer_id = layer_id
+ self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
+ self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
+ self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
+ self.dt_proj_no_bias = FastLinear.load(
+ config, f"{prefix}.dt_proj", weights, bias=False
+ )
+ self.out_proj = FastLinear.load(
+ config, f"{prefix}.out_proj", weights, bias=False
+ )
+ self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
+ self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
+ self.D = weights.get_tensor(f"{prefix}.D")
+ self.activation = "silu"
+ self.dt_rank = config.dt_rank
+ self.d_state = config.d_state
+ self.d_conv = config.d_conv
+ self.act = nn.SiLU()
+
+ # Prefill runs the full selective scan; once a cache offset exists, fall back to the single-token step path.
+ def forward(self, hidden_states: torch.Tensor, inference_params=None):
+ if inference_params.seqlen_offset > 0:
+ conv_state = inference_params.conv_states[self.layer_id]
+ ssm_state = inference_params.ssm_states[self.layer_id]
+ out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
+ return out, conv_state, ssm_state
+
+ _, seqlen, _ = hidden_states.shape
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
+ # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]"
+ x, z = projected_states.chunk(2, dim=1)
+ conv_state = F.pad(x, (self.d_conv - seqlen, 0))
+ x = causal_conv1d_fn(
+ x=x,
+ weight=self.conv1d.weight.squeeze(1),
+ bias=self.conv1d.bias,
+ activation=self.activation,
+ )
+
+ # We're careful here about the layout, to avoid extra transposes.
+ # We want dt to have d as the slowest moving dimension
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
+ dt, B, C = torch.split(
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
+ )
+ dt = self.dt_proj.weight @ dt.t()
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
+ y, last_state = selective_scan_fn(
+ x,
+ dt,
+ self.negA,
+ B,
+ C,
+ self.D.float(),
+ z=z,
+ delta_bias=self.dt_proj.bias.float(),
+ delta_softplus=True,
+ return_last_state=True,
+ )
+ y = rearrange(y, "b d l -> b l d")
+ attn_outputs = self.out_proj(y)
+ return attn_outputs, conv_state, last_state
+
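+ # Single-token decode step: update the cached convolution window and SSM state with the new token
+ # and return the output together with fresh copies of both states.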
+ def step(self, hidden_states, conv_state, ssm_state):
+ xz = self.in_proj(hidden_states.squeeze(1))
+ x, z = xz.chunk(2, dim=-1) # (B D)
+ x = causal_conv1d_update(
+ x,
+ conv_state,
+ self.conv1d.weight.squeeze(1),
+ self.conv1d.bias,
+ self.activation,
+ )
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+ dt = F.linear(dt, self.dt_proj.weight)
+ A = self.negA
+ y = selective_state_update(
+ ssm_state,
+ x,
+ dt,
+ A,
+ B,
+ C,
+ self.D,
+ z=z,
+ dt_bias=self.dt_proj.bias,
+ dt_softplus=True,
+ )
+ out = self.out_proj(y)
+ return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, prefix, config, weights, layer_id):
+ super().__init__()
+ self.mamba_block = MambaBlock(
+ prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id
+ )
+ self.layer_norm = FastRMSNorm.load(
+ prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_epsilon
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ inference_params: Optional[Any] = None,
+ ):
+ residual = (hidden_states + residual) if residual is not None else hidden_states
+ shape = residual.shape
+ hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
+ hidden_states, conv_state, last_ssm_state = self.mamba_block(
+ hidden_states.view(*shape), inference_params
+ )
+ return hidden_states, residual, conv_state, last_ssm_state
+
+
+class MambaModel(nn.Module):
+ def __init__(self, config, weights):
+ super().__init__()
+ prefix = "backbone"
+ try:
+ self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
+ except RuntimeError:
+ self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
+ self.blocks = nn.ModuleList(
+ [
+ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
+ for i in range(config.n_layer)
+ ]
+ )
+ self.norm_f = FastRMSNorm.load(
+ f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
+ )
+ try:
+ self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
+ except RuntimeError:
+ self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
+ self.config = config
+
+ def forward(
+ self, input_ids: torch.Tensor, inference_params=None, residual=None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ hidden_states = self.embed_tokens(input_ids)
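+ # Run each residual block and persist its per-layer conv/SSM state back into the shared cache.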
+ for i, block in enumerate(self.blocks):
+ hidden_states, residual, conv_state, ssm_state = block(
+ hidden_states, residual, inference_params
+ )
+ inference_params.conv_states[i].copy_(conv_state)
+ inference_params.ssm_states[i].copy_(ssm_state)
+
+ hidden_states = (
+ hidden_states + residual if residual is not None else hidden_states
+ )
+ hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
+ hidden_states = hidden_states.view(residual.shape)
+ logits, speculative_logits = self.lm_head(hidden_states)
+
+ # update the offset for the next inference using these params
+ inference_params.seqlen_offset += input_ids.size(1)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py
new file mode 100644
index 000000000..6ba0ffff8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/mllama.py
@@ -0,0 +1,292 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mllama model."""
+
+from typing import Optional, Tuple, List, Union
+
+import torch
+import torch.utils.checkpoint
+
+from optimum.habana.transformers.models import GaudiMllamaForConditionalGeneration
+from optimum.habana.transformers.models.mllama.modeling_mllama import (
+ _prepare_cross_attention_mask,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class MllamaForConditionalGeneration(GaudiMllamaForConditionalGeneration):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ aspect_ratio_mask: Optional[torch.Tensor] = None,
+ aspect_ratio_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = True,
+ flash_attention_recompute: Optional[bool] = True,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ """
+ Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
+ The only differences are:
+ - add token_idx input
+ - add use_flash_attention and flash_attention_recompute
+ """
+ full_text_row_masked_out_mask = kwargs.get(
+ "full_text_row_masked_out_mask", None
+ )
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
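+ # Vision inputs are turned into cross_attention_states/cross_attention_mask in
+ # prepare_inputs_for_generation, so the forward pass only has to run the language model.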
+ outputs = self.language_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cross_attention_states=cross_attention_states,
+ cross_attention_mask=cross_attention_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ num_logits_to_keep=num_logits_to_keep,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+
+ logits = outputs[0]
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return output
+
+ return outputs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids=None,
+ inputs_embeds=None,
+ attention_mask=None,
+ position_ids=None,
+ pixel_values=None,
+ aspect_ratio_ids=None,
+ aspect_ratio_mask=None,
+ cross_attention_mask=None,
+ past_key_values=None,
+ use_cache=False,
+ cache_position=None,
+ num_logits_to_keep=None,
+ **kwargs,
+ ):
+ """
+ Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
+ The only differences are:
+ - add token_idx handling
+ - add bucket_internal handling
+ - add use_flash_attention and flash_attention_recompute
+ """
+
+ token_idx = kwargs.get("token_idx", None)
+ if token_idx is None:
+ return super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ cross_attention_mask=cross_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ else:
+ use_flash_attention = kwargs.get("use_flash_attention", True)
+ flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
+ position_ids = kwargs.get("position_ids", None)
+ output_attentions = kwargs.get("output_attentions", None)
+ output_hidden_states = kwargs.get("output_hidden_states", None)
+ return_dict = kwargs.get("return_dict", None)
+ labels = kwargs.get("labels", None)
+ cross_attention_states = kwargs.get("cross_attention_states", None)
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ bucket_internal = kwargs.get("bucket_internal", None)
+
+ if past_key_values is not None:
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ elif inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif (
+ input_ids.shape[1] != cache_position.shape[0]
+ ): # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+ elif bucket_internal and token_idx is not None:
+ # For the first token we can slice the inputs up to token_idx for the forward pass.
+ input_ids = input_ids[:, :token_idx]
+ attention_mask = attention_mask[:, :token_idx]
+ if cross_attention_mask is not None:
+ cross_attention_mask = cross_attention_mask[:, :token_idx, ...]
+
+ # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(
+ position_ids, 1, token_idx - 1
+ )
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
+ position_ids = position_ids.clone(
+ memory_format=torch.contiguous_format
+ )
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and cross_attention_states is not None:
+ raise ValueError(
+ "`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
+ )
+
+ if pixel_values is not None:
+ if aspect_ratio_ids is None:
+ raise ValueError(
+ "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
+ )
+ # get vision tokens from vision model
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ output_hidden_states=output_hidden_states,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ use_flash_attention=use_flash_attention,
+ )
+ cross_attention_states = vision_outputs[0]
+ cross_attention_states = self.multi_modal_projector(
+ cross_attention_states
+ ).reshape(-1, cross_attention_states.shape[-2], self.hidden_size)
+
+ if cross_attention_mask is not None:
+ cross_attention_mask, full_text_row_masked_out_mask = (
+ _prepare_cross_attention_mask(
+ cross_attention_mask,
+ num_vision_tokens=self.vision_model.num_patches,
+ dtype=self.dtype,
+ token_idx=token_idx,
+ )
+ )
+ else:
+ full_text_row_masked_out_mask = None
+
+ if cross_attention_mask is not None:
+ if cache_position is not None:
+ cross_attention_mask = cross_attention_mask[:, :, cache_position]
+ full_text_row_masked_out_mask = full_text_row_masked_out_mask[
+ :, :, cache_position
+ ]
+ elif past_key_values is not None:
+ if token_idx is not None:
+ cross_attention_mask = torch.index_select(
+ cross_attention_mask, -2, token_idx - 1
+ )
+ full_text_row_masked_out_mask = torch.index_select(
+ full_text_row_masked_out_mask, -2, token_idx - 1
+ )
+ else:
+ cross_attention_mask = cross_attention_mask[:, :, -1:]
+ full_text_row_masked_out_mask = full_text_row_masked_out_mask[
+ :, :, -1:
+ ]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+ else:
+ # The clone here is for the same reason as for `position_ids`.
+ model_inputs = {
+ "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
+ "inputs_embeds": None,
+ }
+
+ if num_logits_to_keep is not None:
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+ # keep cache_position implementation as None for HPU
+ cache_position = None
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "token_idx": token_idx,
+ "labels": labels,
+ "return_dict": kwargs.get("return_dict"),
+ "full_text_row_masked_out_mask": full_text_row_masked_out_mask,
+ "use_flash_attention": use_flash_attention,
+ "cross_attention_mask": cross_attention_mask,
+ "cross_attention_states": cross_attention_states,
+ "output_attentions": output_attentions,
+ "flash_attention_recompute": flash_attention_recompute,
+ }
+ )
+
+ return model_inputs
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
new file mode 100644
index 000000000..441b0016e
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -0,0 +1,946 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2.5 VL model."""
+
+from typing import Optional, Tuple, List
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+
+import numpy as np
+
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+import torch.nn.functional as F
+
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+ Qwen2Model,
+)
+
+# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
+from typing import Union
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput, VideoInput
+from transformers.processing_utils import (
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+ VideosKwargs,
+)
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
+ fps: Union[List[float], float]
+
+
+class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
+ videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ "videos_kwargs": {"fps": 2.0},
+ }
+
+
+class Qwen2_5_VLProcessor(ProcessorMixin):
+ r"""
+ Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
+ [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
+ [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
+ Args:
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(
+ self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
+ ):
+ self.image_token = (
+ "<|image_pad|>"
+ if not hasattr(tokenizer, "image_token")
+ else tokenizer.image_token
+ )
+ self.video_token = (
+ "<|video_pad|>"
+ if not hasattr(tokenizer, "video_token")
+ else tokenizer.video_token
+ )
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[
+ TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
+ ] = None,
+ videos: VideoInput = None,
+ **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
+ - **image_grid_thw** -- List of 3D image grids (temporal, height, width) used by the LLM. Returned when `images` is not `None`.
+ - **video_grid_thw** -- List of 3D video grids (temporal, height, width) used by the LLM. Returned when `videos` is not `None`.
+ - **second_per_grid_ts** -- List of seconds covered by each temporal grid step, per video. Returned when `videos` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ Qwen2_5_VLProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(
+ images=images, videos=None, **output_kwargs["images_kwargs"]
+ )
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if videos is not None:
+ videos_inputs = self.image_processor(
+ images=None, videos=videos, **output_kwargs["images_kwargs"]
+ )
+ video_grid_thw = videos_inputs["video_grid_thw"]
+
+ fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
+ if isinstance(fps, (int, float)):
+ second_per_grid_ts = [
+ self.image_processor.temporal_patch_size / fps
+ ] * len(video_grid_thw)
+ elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
+ second_per_grid_ts = [
+ self.image_processor.temporal_patch_size / tmp for tmp in fps
+ ]
+ else:
+ raise ValueError(
+ f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
+ )
+ videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
+
+ else:
+ videos_inputs = {}
+ video_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
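+ # Expand each image/video token into grid_t * grid_h * grid_w // merge_size**2 placeholder tokens so
+ # that the tokenized text length matches the number of visual embeddings produced for that input.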
+ if image_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ text[i] = text[i].replace(
+ self.image_token,
+ "<|placeholder|>"
+ * (image_grid_thw[index].prod() // merge_length),
+ 1,
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ if video_grid_thw is not None:
+ merge_length = self.image_processor.merge_size**2
+ index = 0
+ for i in range(len(text)):
+ while self.video_token in text[i]:
+ text[i] = text[i].replace(
+ self.video_token,
+ "<|placeholder|>"
+ * (video_grid_thw[index].prod() // merge_length),
+ 1,
+ )
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
+
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def post_process_image_text_to_text(self, generated_outputs):
+ """
+ Post-process the output of the model to decode the text.
+
+ Args:
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+ or `(sequence_length,)`.
+
+ Returns:
+ `List[str]`: The decoded text.
+ """
+ return self.tokenizer.batch_decode(
+ generated_outputs,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ names_from_processor = list(
+ dict.fromkeys(tokenizer_input_names + image_processor_input_names)
+ )
+ return names_from_processor + ["second_per_grid_ts"]
+
+
+# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
+class Qwen2_5_VLVisionConfig(PretrainedConfig):
+ model_type = "qwen2_5_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=32,
+ hidden_size=3584,
+ hidden_act="silu",
+ intermediate_size=3420,
+ num_heads=16,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ spatial_patch_size=14,
+ temporal_patch_size=2,
+ tokens_per_second=4,
+ window_size=112,
+ out_hidden_size=3584,
+ fullatt_block_indexes=[7, 15, 23, 31],
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_patch_size = spatial_patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.tokens_per_second = tokens_per_second
+ self.window_size = window_size
+ self.fullatt_block_indexes = fullatt_block_indexes
+ self.out_hidden_size = out_hidden_size
+
+
+class Qwen2_5_VLConfig(PretrainedConfig):
+
+ def __init__(
+ self,
+ vocab_size=152064,
+ hidden_size=8192,
+ intermediate_size=29568,
+ num_hidden_layers=80,
+ num_attention_heads=64,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=1000000.0,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=80,
+ attention_dropout=0.0,
+ vision_config=None,
+ rope_scaling=None,
+ **kwargs,
+ ):
+ if vision_config is not None:
+ self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.max_window_layers = max_window_layers
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
+ # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
+ # TODO: @raushan update config in the hub
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ if self.rope_scaling["type"] == "mrope":
+ self.rope_scaling["type"] = "default"
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+ tensor: torch.Tensor, freqs: torch.Tensor
+) -> torch.Tensor:
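+ # 2D rotary embedding for vision tokens: cos/sin are repeated along the last dim so they line up
+ # with the half-split layout expected by rotate_half.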
+ orig_dtype = tensor.dtype
+ tensor = tensor.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
+ output = output.to(orig_dtype)
+ return output
+
+
+class Qwen2_5VLAttention(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size // weights.process_group.size()
+ self.head_dim = config.hidden_size // config.num_heads
+ self.num_heads = config.num_heads // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.qkv",
+ weights=weights,
+ bias=False,
+ num_heads=self.num_heads,
+ num_key_value_heads=self.num_heads,
+ )
+ self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
+
+ self.proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.proj",
+ weights=weights,
+ bias=True,
+ )
+ self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: int,
+ ) -> torch.Tensor:
+ # apply the qkv linear layer to the hidden state
+ qkv = self.qkv(hidden_state)
+ query, key, value = qkv.split(
+ [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
+ )
+
+ # reshape the query, key, and value tensors
+ _shape = (
+ hidden_state.shape[0],
+ self.num_heads,
+ self.embed_dim // self.num_heads,
+ )
+ query = query.view(*_shape)
+ key = key.view(*_shape)
+ value = value.view(*_shape)
+
+ # apply rotary positional embeddings
+ query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
+ 0
+ )
+ key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ # make q/k/v contiguous for the fused SDPA kernel; vision attention is bidirectional (non-causal)
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ causal = False
+
+ # execute sdpa
+ query = query.unsqueeze(0).transpose(1, 2)
+ key = key.unsqueeze(0).transpose(1, 2)
+ value = value.unsqueeze(0).transpose(1, 2)
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=causal,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+
+ # reshape output to original dimensions
+ attn_output = attn_output.reshape(hidden_state.shape[0], -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2_5VLVisionMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ self.up = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
+ )
+ self.gate = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
+ )
+ self.down = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_states = self.gate(hidden_states)
+ up_states = self.up(hidden_states)
+ activated_states = self.activation_fn(gate_states) * up_states
+ down_states = self.down(activated_states)
+ return down_states
+
+
+class Qwen2_5VLVisionBlock(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.attn = Qwen2_5VLAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ )
+ self.norm1 = FastRMSNorm.load(
+ prefix=f"{prefix}.norm1",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.norm2 = FastRMSNorm.load(
+ prefix=f"{prefix}.norm2",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.mlp = Qwen2_5VLVisionMLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
+ ) -> torch.Tensor:
+ norm1_out, _ = self.norm1(hidden_states)
+ attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
+ hidden_states = hidden_states + attn_out
+ norm2_out, _ = self.norm2(hidden_states)
+ mlp_out = self.mlp(norm2_out)
+ hidden_states = hidden_states + mlp_out
+ return hidden_states
+
+
+class Qwen2_5VLPatchMerger(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.patch_merger_ln_q = FastRMSNorm.load(
+ prefix=f"{prefix}.ln_q",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states) -> torch.Tensor:
+ hidden_states, _ = self.patch_merger_ln_q(hidden_states)
+ hidden_states = hidden_states.view(-1, self.hidden_size)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = F.gelu(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Qwen2_5VisionModel(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+
+ self.spatial_merge_size = config.spatial_merge_size
+ kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
+ self.patch_embedding = nn.Conv3d(
+ in_channels=config.in_channels,
+ out_channels=config.hidden_size,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
+ )
+ head_dim = config.hidden_size // config.num_heads
+
+ theta = 10000.0
+ dim = head_dim // 2
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ self.blocks = nn.ModuleList(
+ [
+ Qwen2_5VLVisionBlock(
+ prefix=f"{prefix}.blocks.{i}",
+ config=config,
+ weights=weights,
+ )
+ for i in range(config.depth)
+ ]
+ )
+ self.merger = Qwen2_5VLPatchMerger(
+ prefix=f"{prefix}.merger",
+ config=config,
+ weights=weights,
+ )
+ self.temporal_patch_size = config.temporal_patch_size
+ self.spatial_patch_size = config.spatial_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+ self.window_size = config.window_size
+ self.patch_size = config.patch_size
+ self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
+ self.fullatt_block_indexes = config.fullatt_block_indexes
+
+ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
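+ # Partition the flattened patch sequence into vit_merger_window_size x vit_merger_window_size
+ # windows per image: returns the permutation that groups patches window by window plus the
+ # cumulative sequence lengths of those windows (used for windowed attention).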
+ def get_window_index(self, grid_thw):
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = (
+ self.window_size // self.spatial_merge_size // self.patch_size
+ )
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.spatial_merge_size,
+ grid_w // self.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+ grid_t, llm_grid_h, llm_grid_w
+ )
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = (
+ seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ )
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+
+ # reshape the input tensor for processing
+ shape = (
+ -1,
+ self.in_channels,
+ self.temporal_patch_size,
+ self.spatial_patch_size,
+ self.spatial_patch_size,
+ )
+ pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
+ hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
+ # TODO: revisit to see if we can avoid some of these reshapes
+
+ # find the position ids for the input tensor based on the grid_thw
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+ pos_ids = torch.cat(pos_ids, dim=0)
+
+ max_grid_size = grid_thw[:, 1:].max()
+
+ # apply the positional embeddings to the position ids
+ seq = torch.arange(
+ max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+ )
+ rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ seq_len = hidden_states.shape[0]
+ patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ og_shape = (seq_len, -1)
+
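+ # Reorder patches (and their rotary embeddings) into window order so every attention window is a
+ # contiguous slice; the original order is restored after the blocks via argsort(window_index).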
+ hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
+ og_shape
+ )
+ rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
+ og_shape
+ )
+
+ rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
+
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device="cpu",
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
+ hidden_states.device
+ )
+
+ # create a cu_seqlens tensor to be used in the attention mask
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+ ).cumsum(dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
+
+ # iteratively apply the blocks to the hidden states
+ for layer_num, block in enumerate(self.blocks):
+ # NOTE: qwen2_5_vl.py has a concept of full attention blocks
+ # that are applied at specific layers.
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+
+ hidden_states = block(
+ hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
+ )
+
+ # apply the final patch merger to the hidden states
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+ return hidden_states
+
+
+class Qwen2_5VLForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ # restore rope_scaling rope_type to "mrope" since AutoConfig.from_pretrained incorrectly
+ # reports it as "default" for the Qwen2_5-VL model at the moment
+ if (
+ hasattr(config, "rope_scaling")
+ and config.rope_scaling is not None
+ and config.rope_scaling.get("type", None) == "default"
+ ):
+ config.rope_scaling.update({"rope_type": "mrope"})
+ self.hidden_size = config.hidden_size
+ self.vision_start_token_id = config.vision_start_token_id
+ self.vision_end_token_id = config.vision_end_token_id
+ self.image_token_id = config.image_token_id
+ self.video_token_id = config.video_token_id
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix="model.embed_tokens", weights=weights
+ )
+ self.visual = Qwen2_5VisionModel(
+ prefix="visual", config=config.vision_config, weights=weights
+ )
+ self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=suffix if not prefix else f"{prefix}.{suffix}",
+ weights=weights,
+ )
+ self.device = weights.device
+
+ # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+ # modified to first find segments then initialize position ids for each segment
+ # Steps:
+ # locate all vision and text segments
+ # calculate `vision_segment_lengths` for each vision segment to be used as an offset
+ # calculate `text_segment_lengths` for each text segment to be used as offset
+ # create position ids for each vision segment based on the image grid
+ # create position ids for each text segment
+ # append position ids for the trailing text after the last vision segment
+ # combine all the position ids and reshape to (3, input_ids_len), then swap dimensions to (input_ids_len, 3)
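+ # Text tokens share the same index across the three rows (temporal, height, width), while vision
+ # tokens get separate t/h/w indices, each offset by the lengths of the preceding segments.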
+ def get_position_ids(
+ self,
+ input_ids: torch.Tensor,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if image_grid_thw is None:
+ return (
+ torch.arange(input_ids.shape[0], device=input_ids.device)
+ .unsqueeze(1)
+ .repeat(1, 3)
+ )
+
+ spatial_merge_size = self.spatial_merge_size
+ vision_start_token_id = self.vision_start_token_id
+ vision_end_token_id = self.vision_end_token_id
+ device = input_ids.device
+ dtype = input_ids.dtype
+ input_ids_len = input_ids.shape[0]
+
+ vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+ vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+ vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+ prev_vision_end = torch.cat(
+ [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+ )
+ text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+ vision_widths_max = torch.cat(
+ [
+ torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+ image_grid_thw[:-1, 2] // spatial_merge_size,
+ ]
+ )
+ vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+ vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+ text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
+
+ # create position ids for each vision segment based on the image grid
+ llm_pos_ids_list = []
+ for i, _ in enumerate(vision_segments):
+ t, h, w = (
+ image_grid_thw[i][0],
+ image_grid_thw[i][1] // spatial_merge_size,
+ image_grid_thw[i][2] // spatial_merge_size,
+ )
+ t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+ h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+ w_indices = torch.arange(w, device=device).repeat(t * h)
+ image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+ # offset by the position of the last vision segment
+ im = image_position_ids + vision_segment_lengths[i]
+ llm_pos_ids_list.append(im)
+
+ # create position ids for each text segment
+ text_ranges = [
+ torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ + text_segment_lengths[i]
+ for i, seq_len in enumerate(text_lengths_between_vision)
+ ]
+
+ full_llm_pos_ids_list = [
+ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+ ]
+ max_s = full_llm_pos_ids_list[-1].max() + 1
+ final_text_len = input_ids_len - vision_ends[-1]
+ if final_text_len > 0:
+ m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+ full_llm_pos_ids_list.append(m + max_s)
+
+ position_ids = (
+ torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+ )
+ return position_ids
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor],
+ pixel_values: torch.FloatTensor = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ # Unused in this model
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ pixel_attention_mask=None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # apply the visual model to the pixel values if they are provided
+ if pixel_values is not None and len(pixel_values) > 0:
+ if pixel_values is not None:
+ image_embeds = self.visual(
+ pixel_values, grid_thw=image_grid_thw
+ ).squeeze(0)
+ mask = torch.where(input_ids == self.image_token_id)
+ inputs_embeds[mask] = image_embeds
+
+ hidden_states = self.text_model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
new file mode 100644
index 000000000..47ae2ac94
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -0,0 +1,519 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 VL model."""
+
+from typing import Optional, Tuple, List
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+
+from habana_frameworks.torch.hpex.kernels import FusedSDPA
+from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+
+import numpy as np
+
+from transformers.activations import ACT2FN
+import torch.nn.functional as F
+
+from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
+from text_generation_server.layers import (
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+ TensorParallelEmbedding,
+ SpeculativeHead,
+)
+from text_generation_server.layers.attention import (
+ Seqlen,
+ HPUPagedAttentionMetadata,
+)
+from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
+ Qwen2Model,
+)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
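+ # Added example (illustrative): for a last dimension [x0, x1, x2, x3] this returns
+ # [-x2, -x3, x0, x1], i.e. the second half is negated and moved in front of the first half.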
+
+
+def apply_rotary_pos_emb_vision(
+ tensor: torch.Tensor, freqs: torch.Tensor
+) -> torch.Tensor:
+ orig_dtype = tensor.dtype
+ tensor = tensor.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
+ output = output.to(orig_dtype)
+ return output
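+ # Added note (assumed shapes, matching the vision attention below): `tensor` is
+ # (1, seq_len, num_heads, head_dim) and `freqs` is (seq_len, head_dim // 2); cos/sin are
+ # tiled to head_dim so that channel i and channel i + head_dim // 2 share a frequency,
+ # which is exactly the pairing rotate_half expects.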
+
+
+class Qwen2VLAttention(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.embed_dim = config.embed_dim // weights.process_group.size()
+ self.head_dim = config.hidden_size // config.num_heads
+ self.num_heads = config.num_heads // weights.process_group.size()
+
+ self.qkv = TensorParallelColumnLinear.load_qkv(
+ config,
+ prefix=f"{prefix}.qkv",
+ weights=weights,
+ bias=False,
+ num_heads=self.num_heads,
+ num_key_value_heads=self.num_heads,
+ )
+ self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
+ self.proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.proj",
+ weights=weights,
+ bias=True,
+ )
+ self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
+
+ def forward(
+ self,
+ hidden_state: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: int,
+ ) -> torch.Tensor:
+ # apply the qkv linear layer to the hidden state
+ qkv = self.qkv(hidden_state)
+ query, key, value = qkv.split(
+ [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
+ )
+
+ # reshape the query, key, and value tensors
+ _shape = (
+ hidden_state.shape[0],
+ self.num_heads,
+ self.embed_dim // self.num_heads,
+ )
+ query = query.view(*_shape)
+ key = key.view(*_shape)
+ value = value.view(*_shape)
+
+ # apply rotary positional embeddings
+ query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
+ 0
+ )
+ key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ # make query/key/value contiguous before the fused SDPA call
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ causal = False
+
+ # execute sdpa
+ query = query.unsqueeze(0).transpose(1, 2)
+ key = key.unsqueeze(0).transpose(1, 2)
+ value = value.unsqueeze(0).transpose(1, 2)
+ fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+ attn_output = fsdpa_op(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=causal,
+ scale=None,
+ softmax_mode="None",
+ recompute_mode=None,
+ valid_sequence_lengths=None,
+ )
+ attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
+ # reshape output to original dimensions
+ attn_output = attn_output.reshape(hidden_state.shape[0], -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2VLVisionMLP(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Qwen2VLVisionBlock(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.attn = Qwen2VLAttention(
+ prefix=f"{prefix}.attn",
+ config=config,
+ weights=weights,
+ )
+ self.norm1 = FastLayerNorm.load(
+ prefix=f"{prefix}.norm1",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.norm2 = FastLayerNorm.load(
+ prefix=f"{prefix}.norm2",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.mlp = Qwen2VLVisionMLP(
+ prefix=f"{prefix}.mlp",
+ config=config,
+ weights=weights,
+ )
+
+ def forward(
+ self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
+ ) -> torch.Tensor:
+ norm1_out, residual = self.norm1(hidden_states)
+ attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
+ hidden_states = attn_out + residual
+ norm2_out, residual = self.norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(norm2_out)
+ return hidden_states
+
+
+class Qwen2VLPatchMerger(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
+ self.patch_merger_ln_q = FastLayerNorm.load(
+ prefix=f"{prefix}.ln_q",
+ weights=weights,
+ eps=1e-6,
+ )
+ self.fc1 = TensorParallelColumnLinear.load(
+ prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load(
+ prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
+ )
+
+ def forward(self, hidden_states) -> torch.Tensor:
+ hidden_states, _ = self.patch_merger_ln_q(hidden_states)
+ hidden_states = hidden_states.view(-1, self.hidden_size)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = F.gelu(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
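+ # Shape sketch (added, illustrative): with e.g. embed_dim = 1280 and spatial_merge_size = 2,
+ # an input of (num_patches, 1280) is layer-normed, regrouped into (num_patches / 4, 5120)
+ # so that each row holds one 2x2 window of patches, then projected through mlp.0, GELU and
+ # mlp.2 to the language model's hidden size.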
+
+
+class Qwen2VisionModel(nn.Module):
+ def __init__(self, *, prefix, config, weights):
+ super().__init__()
+ self.spatial_merge_size = config.spatial_merge_size
+ kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
+ self.patch_embedding = nn.Conv3d(
+ in_channels=config.in_chans,
+ out_channels=config.embed_dim,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False,
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
+ )
+ head_dim = config.embed_dim // config.num_heads
+ # TODO: replace with static positional embeddings once implemented
+ theta = 10000.0
+ dim = head_dim // 2
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ self.blocks = nn.ModuleList(
+ [
+ Qwen2VLVisionBlock(
+ prefix=f"{prefix}.blocks.{i}",
+ config=config,
+ weights=weights,
+ )
+ for i in range(config.depth)
+ ]
+ )
+ self.merger = Qwen2VLPatchMerger(
+ prefix=f"{prefix}.merger",
+ config=config,
+ weights=weights,
+ )
+
+ self.temporal_patch_size = config.temporal_patch_size
+ self.spatial_patch_size = config.spatial_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.embed_dim
+
+ def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ batch_size, _, hidden_size = hidden_state.shape
+ class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+ hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+ return hidden_state
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ # reshape the input tensor for processing
+ shape = (
+ -1,
+ self.in_channels,
+ self.temporal_patch_size,
+ self.spatial_patch_size,
+ self.spatial_patch_size,
+ )
+ pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
+ hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
+ # TODO: revisit to see if we can avoid some of these reshapes
+
+ # find the position ids for the input tensor based on the grid_thw
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+ pos_ids = torch.cat(pos_ids, dim=0)
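+ # Worked example (added, illustrative): for one image with h = w = 4 and
+ # spatial_merge_size = 2, hpos_ids is reshaped to (2, 2, 2, 2) and permuted so that the
+ # flattened order walks the four 2x2 merge windows one after another:
+ # hpos = [0,0,1,1, 0,0,1,1, 2,2,3,3, 2,2,3,3], wpos = [0,1,0,1, 2,3,2,3, 0,1,0,1, 2,3,2,3],
+ # i.e. patches that will be merged together end up adjacent in the sequence.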
+ max_grid_size = grid_thw[:, 1:].max()
+
+ # apply the positional embeddings to the position ids
+ seq = torch.arange(
+ max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+ )
+ rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
+
+ # create a cu_seqlens tensor to be used in the attention mask
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+ ).cumsum(dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
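+ # Worked example (added, illustrative): for two images with grid_thw = [(1, 16, 16), (1, 8, 8)]
+ # the per-image patch counts are [256, 64], so cu_seqlens = [0, 256, 320] and max_seqlen = 256,
+ # marking where each image's patches start and end in the packed sequence.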
+ # iteratively apply the transformer blocks to the hidden states
+ for block in self.blocks:
+ hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
+
+ # apply the final patch merger to the hidden states
+ hidden_states = self.merger(hidden_states)
+ return hidden_states
+
+
+class Qwen2VLForConditionalGeneration(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ config.vision_config.quantize = None
+ config.vision_config.speculator = config.speculator
+ # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
+ # returns rope_scaling.type == "default" for the Qwen2-VL model at the moment
+ if (
+ hasattr(config, "rope_scaling")
+ and config.rope_scaling is not None
+ and config.rope_scaling.get("type", None) == "default"
+ ):
+ config.rope_scaling.update({"rope_type": "mrope"})
+ self.hidden_size = config.hidden_size
+ self.vision_start_token_id = config.vision_start_token_id
+ self.vision_end_token_id = config.vision_end_token_id
+ self.image_token_id = config.image_token_id
+ self.video_token_id = config.video_token_id
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix="model.embed_tokens", weights=weights
+ )
+ self.visual = Qwen2VisionModel(
+ prefix="visual", config=config.vision_config, weights=weights
+ )
+ self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
+ if config.tie_word_embeddings:
+ suffix = "model.embed_tokens"
+ else:
+ suffix = "lm_head"
+
+ self.lm_head = SpeculativeHead.load(
+ config,
+ prefix=suffix if not prefix else f"{prefix}.{suffix}",
+ weights=weights,
+ )
+ self.norm = FastRMSNorm.load(
+ prefix="model.norm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+ self.device = weights.device
+
+ # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
+ # modified to first find the segments, then initialize position ids for each segment
+ # Steps:
+ # locate all vision and text segments
+ # calculate `vision_segment_lengths` for each vision segment to be used as an offset
+ # calculate `text_segment_lengths` for each text segment to be used as an offset
+ # create position ids for each vision segment based on the image grid
+ # create position ids for each text segment
+ # the final text segment covers the tokens between the last vision segment and the end of the input
+ # combine all the position ids, reshape to (3, input_ids_len), then transpose to (input_ids_len, 3)
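+ #
+ # Added note (illustrative): when `image_grid_thw` is None the early return below simply
+ # repeats torch.arange(input_ids_len) across the three mrope axes, e.g. a 4-token text-only
+ # prompt gets position ids [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]].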
+ def get_position_ids(
+ self,
+ input_ids: torch.Tensor,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if image_grid_thw is None:
+ return (
+ torch.arange(input_ids.shape[0], device=input_ids.device)
+ .unsqueeze(1)
+ .repeat(1, 3)
+ )
+
+ spatial_merge_size = self.spatial_merge_size
+ vision_start_token_id = self.vision_start_token_id
+ vision_end_token_id = self.vision_end_token_id
+ device = input_ids.device
+ dtype = input_ids.dtype
+ input_ids_len = input_ids.shape[0]
+
+ vision_starts = torch.where(input_ids == vision_start_token_id)[0]
+ vision_ends = torch.where(input_ids == vision_end_token_id)[0]
+ vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
+ prev_vision_end = torch.cat(
+ [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
+ )
+ text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
+ vision_widths_max = torch.cat(
+ [
+ torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
+ image_grid_thw[:-1, 2] // spatial_merge_size,
+ ]
+ )
+ vision_segment_lengths = vision_widths_max + text_lengths_between_vision
+ vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
+ text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
+
+ # create position ids for each vision segment based on the image grid
+ llm_pos_ids_list = []
+ for i, _ in enumerate(vision_segments):
+ t, h, w = (
+ image_grid_thw[i][0],
+ image_grid_thw[i][1] // spatial_merge_size,
+ image_grid_thw[i][2] // spatial_merge_size,
+ )
+ t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
+ h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+ w_indices = torch.arange(w, device=device).repeat(t * h)
+ image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
+
+ # offset by the position of the last vision segment
+ im = image_position_ids + vision_segment_lengths[i]
+ llm_pos_ids_list.append(im)
+
+ # create position ids for each text segment
+ text_ranges = [
+ torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ + text_segment_lengths[i]
+ for i, seq_len in enumerate(text_lengths_between_vision)
+ ]
+
+ full_llm_pos_ids_list = [
+ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
+ ]
+ max_s = full_llm_pos_ids_list[-1].max() + 1
+ final_text_len = input_ids_len - vision_ends[-1]
+ if final_text_len > 0:
+ m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
+ full_llm_pos_ids_list.append(m + max_s)
+
+ position_ids = (
+ torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
+ )
+ return position_ids
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ slots: torch.Tensor,
+ seqlen: Seqlen,
+ hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+ lm_head_indices: Optional[torch.Tensor],
+ pixel_values: torch.FloatTensor = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ pixel_attention_mask=None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ adapter_data: Optional[torch.Tensor] = None,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ image_indices=None,
+ ):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # apply the visual model to the pixel values if they are provided
+ if pixel_values is not None and len(pixel_values) > 0:
+ if pixel_values is not None:
+ image_embeds = self.visual(
+ pixel_values, grid_thw=image_grid_thw
+ ).squeeze(0)
+ mask = torch.where(input_ids == self.image_token_id)
+ inputs_embeds[mask] = image_embeds
+
+ hidden_states = self.text_model(
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=seqlen,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits, speculative_logits = self.lm_head(hidden_states)
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py
new file mode 100644
index 000000000..95ac9edee
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py
@@ -0,0 +1,410 @@
+from typing import Optional, Tuple
+import warnings
+import math
+import torch
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPooling,
+)
+from transformers import SiglipConfig, SiglipVisionConfig
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from text_generation_server.layers.tensor_parallel import (
+ TensorParallelEmbedding,
+ TensorParallelColumnLinear,
+ TensorParallelRowLinear,
+)
+
+
+class SiglipVisionEmbeddings(nn.Module):
+ def __init__(self, prefix, config: SiglipVisionConfig, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ self.patch_embedding.weight = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
+ )
+ self.patch_embedding.bias = nn.Parameter(
+ weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
+ )
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = TensorParallelEmbedding(
+ prefix=f"{prefix}.position_embedding", weights=weights
+ )
+ self.register_buffer(
+ "position_ids",
+ torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ patch_embeds = self.patch_embedding(
+ pixel_values
+ ) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class SiglipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.head_size = self.head_dim
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.embed_dim = self.embed_dim // weights.process_group.size()
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
+ )
+ self.v_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
+ )
+ self.q_proj = TensorParallelColumnLinear.load(
+ config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
+ )
+ self.out_proj = TensorParallelRowLinear.load(
+ config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states)
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ # scale post matmul
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = (
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ + attention_mask
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(attn_weights.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_size)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SiglipMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size
+ prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
+ )
+ self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size
+ prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SiglipEncoderLayer(nn.Module):
+ def __init__(self, prefix, config: SiglipConfig, weights):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = SiglipAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.layer_norm1 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
+ )
+ self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+ self.layer_norm2 = nn.LayerNorm.load(
+ prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> Tuple[torch.FloatTensor]:
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ )
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states, None
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, prefix, config: SiglipVisionConfig, weights):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(
+ config.hidden_size, config.num_attention_heads, batch_first=True
+ )
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(prefix, config, weights)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: torch.Tensor,
+ mean: float = 0.0,
+ std: float = 1.0,
+ a: float = -2.0,
+ b: float = 2.0,
+) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \\leq \\text{mean} \\leq b`.
+
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+ and the result is subsequently scaled and shifted by the mean and std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
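+ # Usage sketch (added, illustrative): these initializers mirror the JAX/Flax defaults used by
+ # the original SigLIP weights, e.g. `lecun_normal_(torch.empty(768, 768))` draws from a
+ # truncated normal with std = sqrt(1 / fan_in) (rescaled by ~1/0.8796 to compensate for the
+ # truncation), while `default_flax_embed_init` uses a plain normal with the same variance.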
+
+
+class SiglipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SiglipEncoderLayer`].
+
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, prefix, config: SiglipConfig, weights):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [
+ SiglipEncoderLayer(
+ prefix=f"{prefix}.layers.{i}", config=config, weights=weights
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ hidden_states, _ = encoder_layer(
+ hidden_states,
+ attention_mask,
+ )
+
+ return hidden_states
+
+
+class SiglipVisionTransformer(nn.Module):
+ def __init__(self, prefix, config: SiglipVisionConfig, weights):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = SiglipVisionEmbeddings(
+ prefix=f"{prefix}.embeddings", config=config, weights=weights
+ )
+ self.encoder = SiglipEncoder(
+ prefix=f"{prefix}.encoder", config=config, weights=weights
+ )
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ ):
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+
+ # NOTE: up until this point the hidden states are exactly
+ # the same as in the transformers code. The values evaluate
+ # slightly differently in our encoder layer.
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ )
+ last_hidden_state = encoder_outputs
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ # pooler_output=pooled_output,
+ # hidden_states=encoder_outputs,
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
new file mode 100644
index 000000000..ae704af31
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
@@ -0,0 +1,54 @@
+def load_text_model(prefix, config, weights, name=None):
+ if config.model_type == "llama":
+ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
+ FlashLlamaForCausalLM,
+ )
+
+ return FlashLlamaForCausalLM(prefix, config, weights, name=name)
+ elif config.model_type == "mistral":
+ from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
+ FlashMistralForCausalLM,
+ )
+
+ return FlashMistralForCausalLM(prefix, config, weights, name=name)
+ elif config.model_type == "gemma":
+ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+ FlashGemmaForCausalLM,
+ )
+
+ return FlashGemmaForCausalLM(prefix, config, weights)
+ elif config.model_type == "gemma2":
+ from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+ FlashGemma2ForCausalLM,
+ )
+
+ return FlashGemma2ForCausalLM(prefix, config, weights)
+ elif config.model_type == "paligemma":
+ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+ FlashGemmaForCausalLM,
+ )
+
+ return FlashGemmaForCausalLM(prefix, config, weights)
+ else:
+ raise RuntimeError(f"Unsupported model type {config.model_type}")
+
+
+def load_vision_model(prefix, config, weights):
+ if config.model_type == "clip_vision_model":
+ from text_generation_server.models.custom_modeling.clip import (
+ CLIPVisionTransformer,
+ )
+
+ return CLIPVisionTransformer(
+ prefix=f"{prefix}.vision_model", config=config, weights=weights
+ )
+ if config.model_type == "siglip_vision_model":
+ from text_generation_server.models.custom_modeling.siglip import (
+ SiglipVisionTransformer,
+ )
+
+ return SiglipVisionTransformer(
+ prefix="vision_tower.vision_model", config=config, weights=weights
+ )
+ else:
+ raise RuntimeError(f"Unsupported model type {config.model_type}")
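+ #
+ # Usage sketch (added, illustrative; prefixes are examples only): a multimodal wrapper would
+ # typically call load_vision_model(prefix="vision_tower", config=config.vision_config,
+ # weights=weights) and load_text_model(prefix="language_model", config=config.text_config,
+ # weights=weights), then scatter the image embeddings into the text model's input embeddings.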
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
new file mode 100644
index 000000000..a4d58596b
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -0,0 +1,2181 @@
+import math
+import os
+import time
+import torch
+import torch.distributed
+
+import numpy as np
+
+from loguru import logger
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import (
+ PreTrainedTokenizerBase,
+ AutoConfig,
+ AutoTokenizer,
+ GenerationConfig,
+)
+from typing import (
+ Any,
+ Iterable,
+ Optional,
+ Tuple,
+ List,
+ Type,
+ Dict,
+ Union,
+)
+import torch.nn.functional as F
+from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
+from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.models import Model
+from text_generation_server.utils.log import log_master
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+from text_generation_server.models.types import (
+ Batch,
+ Tokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.globals import (
+ BLOCK_SIZE,
+ REQUEST_LOGPROBS,
+ TGI_WIGGLE_ROOM,
+ get_adapter_to_index,
+)
+from text_generation_server.layers.attention import (
+ KVCache,
+ Seqlen,
+ HPUPagedAttentionMetadata,
+ trim_attn_metadata,
+ trim_seqlen_metadata,
+)
+from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
+from text_generation_server.utils.dist import MEMORY_FRACTION
+from text_generation_server.utils.quantization import get_loader
+from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
+from text_generation_server.utils.import_utils import (
+ empty_cache,
+ synchronize,
+ get_free_memory,
+)
+
+import vllm_hpu_extension.environment as environment
+import habana_frameworks.torch as htorch
+import itertools
+from vllm_hpu_extension.ops import batch2block, block2batch
+
+tracer = trace.get_tracer(__name__)
+
+# Will be set in init
+SLIDING_WINDOW: Optional[int] = None
+
+
+def set_sliding_window(sliding_window: int):
+ global SLIDING_WINDOW
+ SLIDING_WINDOW = sliding_window
+
+
+def get_sliding_windows() -> int:
+ global SLIDING_WINDOW
+ return SLIDING_WINDOW
+
+
+def prepare_for_decode(
+ dtype, use_contiguous_pa, device, slot, block_tables, batch_size
+):
+ # Prepare values if we need to continue decoding
+ # need for HPUPagedAttentionMetadata preparation
+ def flatten(in_list):
+ return list(itertools.chain(*in_list))
+
+ def gather_list(input, indices, v):
+ return [input[i] if i is not None else v for i in indices]
+
+ def pad_list(input, k, v):
+ input_len = len(input)
+ target_len = (input_len + k - 1) // k * k
+ padding = target_len - input_len
+ return input + [v] * padding
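+ # Added examples (illustrative): pad_list([1, 2, 3], k=2, v=0) -> [1, 2, 3, 0] (length rounded
+ # up to a multiple of k), and gather_list(["a", "b"], [1, None, 0], v="-") -> ["b", "-", "a"]
+ # (None indices are filled with the default value).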
+
+ last_block_usage = slot % BLOCK_SIZE + 1
+ block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+ block_usage = [
+ [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
+ for bt, lbu in zip(block_tables, last_block_usage)
+ if bt
+ ]
+
+ block_list = flatten(block_tables)
+ block_groups = flatten(block_groups)
+ block_usage = flatten(block_usage)
+ assert len(block_list) == len(block_groups)
+ assert len(block_list) == len(block_usage)
+ if use_contiguous_pa:
+ block_bucket_size = max(max(block_list) + 1, len(block_list))
+ # block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
+ # block_bucket_size)
+ indices: List[Any]
+ indices = [None] * block_bucket_size
+ for i, bid in enumerate(block_list):
+ indices[bid] = i
+ block_list = gather_list(block_list, indices, 0)
+ block_groups = gather_list(block_groups, indices, -1)
+ block_usage = gather_list(block_usage, indices, 1)
+ else:
+ block_bucket_size = len(block_list)
+ block_list = pad_list(block_list, block_bucket_size, 0)
+ block_groups = pad_list(block_groups, block_bucket_size, -1)
+ block_usage = pad_list(block_usage, block_bucket_size, 1)
+
+ block_list = torch.tensor(block_list, dtype=torch.int, device=device)
+ block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
+ block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
+ block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
+ mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+ mask = mask >= block_usage.unsqueeze(-1)
+ attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+ ones = torch.ones(
+ (block_mapping.size(0),), device=device, dtype=block_mapping.dtype
+ )
+ sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+ block_scales = torch.reciprocal(torch.maximum(ones, sums))
+ return trim_attn_metadata(
+ HPUPagedAttentionMetadata(
+ block_list=block_list,
+ block_groups=block_groups,
+ block_usage=block_usage,
+ block_mapping=block_mapping.to(dtype),
+ attn_bias=attn_bias,
+ block_scales=block_scales,
+ )
+ )
+
+
+@dataclass
+class FlashCausalLMBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ # request id -> idx in list mapping
+ requests_idx_mapping: Dict[int, int]
+
+ # Decoder values
+ # Can be a list for easy filtering
+ # If `input_ids` is a list, it needs to be materialized to a tensor first
+ input_ids: Union[torch.Tensor, List[List[int]]]
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ position_ids: Optional[torch.Tensor]
+ speculative_ids: Optional[torch.Tensor]
+
+ # Set when creating the batch
+ # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ slot_indices: Optional[torch.Tensor]
+
+ # list of length b of list of length s_i // block_size
+ block_tables: List[List[int]]
+ # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
+ block_tables_tensor: torch.Tensor
+ # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
+ slots: torch.Tensor
+ # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
+ # used for filtering
+ cu_slots: torch.Tensor
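+ # e.g. (added comment): two requests using 3 and 2 slots respectively give
+ # cu_slots = [0, 3, 5]; request i owns slots[cu_slots[i]:cu_slots[i + 1]]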
+
+ max_input_length: int
+ max_current_length: int
+
+ # Whether this batch contains at least one request that is prefilling
+ prefilling: bool
+ # Whether each request is prefilling
+ prefilling_mask: List[bool]
+
+ # Prefill metadata tensors to efficiently compute logprobs
+ # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
+ cu_seqlen_prefill: Optional[torch.Tensor]
+ # Prefill cache indices are used to slice into the kv tensor before caching it into the paged attention buffers
+ # as we only keep SLIDING_WINDOW values instead of the whole tensor
+ prefill_cache_indices: Optional[torch.Tensor]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_head_indices: Optional[torch.Tensor]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_next_token_indices: Optional[torch.Tensor]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_cu_outlens: Optional[List[int]]
+ # Will be set by `generate_token` and reset after each prefill forward
+ prefill_logprob_tokens: List[Optional[Tokens]]
+
+ # All tokens
+ all_input_ids: List[List[int]]
+ all_input_ids_tensor: torch.Tensor
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ # size [b], containing the number of blocks that can be retrieved from the cache
+ cache_lengths: List[int]
+ prompt_lengths: List[int]
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ input_lengths_tensor: Optional[torch.Tensor]
+ cache_lengths_tensor: Optional[torch.Tensor]
+ prompt_lengths_tensor: torch.Tensor
+
+ prefix_offsets: List[Optional[int]]
+ read_offsets: List[Optional[int]]
+
+ # Generation helpers
+ next_token_chooser: HeterogeneousNextTokenChooser
+ stopping_criterias: List[StoppingCriteria]
+ top_n_tokens: List[int]
+ top_n_tokens_tensor: torch.Tensor
+
+ # Adapter metadata for each request
+ # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
+ adapter_meta: Optional[AdapterBatchMetadata]
+
+ # Number of blocks in this batch
+ num_blocks: int
+ # Maximum number of blocks
+ max_blocks: int
+
+ hpu_attn_meta: Optional[HPUPagedAttentionMetadata]
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.num_blocks * BLOCK_SIZE,
+ current_tokens=(
+ sum([len(i) for i in self.input_ids])
+ if isinstance(self.input_ids, list)
+ else len(self.input_ids)
+ ),
+ )
+
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[generate_pb2.Request], tokenizer
+ ):
+ max_length = 0
+ all_input_ids = []
+ batch_size = 0
+ for r in requests:
+ batch_size += 1
+ inputs = concat_text_chunks(r.input_chunks.chunks)
+ input_ids = tokenizer(
+ inputs,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=r.add_special_tokens,
+ )["input_ids"]
+ max_length = max(max_length, len(input_ids))
+ all_input_ids.append(input_ids)
+ return all_input_ids
+
+ @classmethod
+ def from_tokenized(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ batch_tokenized_inputs,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashCausalLMBatch":
+ cache_lengths = []
+ input_lengths = []
+ prompt_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ all_postfix_ids = []
+ requests_idx_mapping = {}
+ slots = []
+ cu_slots = [0]
+
+ next_token_chooser_parameters = []
+ stopping_criterias = []
+ top_n_tokens = []
+
+ num_blocks = 0
+ max_input_length = 0
+ max_current_length = 0
+ max_length = 0
+ max_blocks = 0
+
+ cu_blocks = [0]
+ block_tables = []
+ block_tables_ragged = []
+
+ # Parse batch
+ for i, (r, tokenized_input) in enumerate(
+ zip(pb.requests, batch_tokenized_inputs)
+ ):
+ ### XXX: This consumes so much memory on long requests
+ ### Deactivating it by default seems like the best course.
+ if not REQUEST_LOGPROBS:
+ r.prefill_logprobs = False
+ # request id -> idx in list mapping
+ requests_idx_mapping[r.id] = i
+
+ prompt_length = len(tokenized_input)
+ prompt_lengths.append(prompt_length)
+
+ cache_length = r.cache_len
+
+ assert (
+ cache_length <= prompt_length
+ ), f"Prefix {cache_length} vs input {prompt_length}"
+ if cache_length == prompt_length:
+ assert False, "unreachable"
+
+ # `chunk_len` is an optional field in the protobuf
+ # It is only set if the model supports chunking
+ # Use all the remaining ids
+ postfix_ids = tokenized_input[cache_length:]
+ input_length = len(postfix_ids)
+
+ input_lengths.append(input_length)
+
+ prefix_offsets.append(prompt_length - 5)
+ read_offsets.append(prompt_length)
+
+ all_postfix_ids.append(postfix_ids)
+ all_input_ids.append(tokenized_input)
+
+ next_token_chooser_parameters.append(r.parameters)
+
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ max_new_tokens = stopping_criteria.max_new_tokens
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(r.top_n_tokens)
+
+ # Paged attention
+ # Remove one as the first token does not have a past
+ speculative_length = get_speculate()
+ speculative_length = 0 if speculative_length is None else speculative_length
+
+ # Tokens that need to be mapped to blocks.
+ block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
+
+ # blocks and slots can be empty (for example in warmup)
+ if not r.blocks:
+ needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
+ request_blocks = [
+ b for b in range(num_blocks, num_blocks + needed_blocks)
+ ]
+ request_slots = [
+ s
+ for b in request_blocks
+ for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+ ]
+ else:
+ request_blocks = r.blocks
+ request_slots = r.slots
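+ # Worked example (added, illustrative): with BLOCK_SIZE = 16, a 40-token prompt,
+ # max_new_tokens = 20 and no speculation, block_tokens = 40 + 20 - 1 = 59, so
+ # needed_blocks = ceil(59 / 16) = 4 blocks covering 64 slots for this request
+ # (offset by the blocks already assigned to earlier requests in the batch).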
+
+ block_tables.append(request_blocks)
+ block_tables_ragged.extend(request_blocks)
+ cu_blocks.append(len(block_tables_ragged))
+
+ slots.extend(request_slots)
+ cu_slots.append(len(slots))
+
+ cache_lengths.append(cache_length)
+ num_blocks += len(request_blocks)
+
+ # Update
+ max_blocks = max(max_blocks, len(request_blocks))
+ max_input_length = max(max_input_length, input_length)
+ max_current_length = max(max_current_length, cache_length + input_length)
+ max_length = max(
+ max_length,
+ prompt_length + max_new_tokens + speculative_length,
+ )
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ next_token_chooser_parameters, dtype, device, tokenizer
+ )
+
+ # Padded all_input_ids_tensor
+ all_input_ids_tensor = np.zeros(
+ (len(all_input_ids), max_length), dtype=np.int64
+ )
+ for i, input_ids in enumerate(all_input_ids):
+ all_input_ids_tensor[i, : len(input_ids)] = input_ids
+
+ # Create tensors on device
+ all_input_ids_tensor = torch.tensor(
+ all_input_ids_tensor, dtype=torch.int64, device=device
+ )
+
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+
+ block_tables_ragged = torch.tensor(
+ block_tables_ragged, device=device, dtype=torch.int32
+ )
+ cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
+ block_tables_tensor = torch.empty(
+ (len(block_tables), max_blocks),
+ device=device,
+ dtype=torch.int32,
+ )
+
+ for i, request_blocks in enumerate(block_tables):
+ block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+
+ prompt_lengths_tensor = torch.tensor(
+ prompt_lengths, dtype=torch.int32, device=device
+ )
+
+ slots = torch.tensor(slots, dtype=torch.int64, device=device)
+ cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
+
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=all_postfix_ids,
+ block_tables=block_tables,
+ block_tables_tensor=block_tables_tensor,
+ cache_lengths=cache_lengths,
+ max_input_length=max_input_length,
+ max_current_length=max_current_length,
+ prefilling=True,
+ prefilling_mask=[True] * len(pb.requests),
+ prefill_logprob_tokens=[None] * len(pb.requests),
+ input_lengths=input_lengths,
+ prompt_lengths=prompt_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ all_input_ids=all_input_ids,
+ all_input_ids_tensor=all_input_ids_tensor,
+ next_token_chooser=next_token_chooser,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ num_blocks=num_blocks,
+ max_blocks=max_blocks,
+ speculative_ids=None,
+ prompt_lengths_tensor=prompt_lengths_tensor,
+ # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+ position_ids=None,
+ cu_seqlen_prefill=None,
+ prefill_cache_indices=None,
+ slot_indices=None,
+ slots=slots,
+ cu_slots=cu_slots,
+ prefill_head_indices=None,
+ prefill_next_token_indices=None,
+ prefill_cu_outlens=None,
+ cache_lengths_tensor=None,
+ input_lengths_tensor=None,
+ adapter_meta=None,
+ hpu_attn_meta=None,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashCausalLMBatch":
+ assert len(pb.requests) > 0
+ batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+ return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ # We assume that if len(requests) == len(self) then the requests are the same
+ if len(request_ids) == len(self):
+ return self
+
+ device = self.block_tables_tensor.device
+
+ # New values after filtering
+ requests_idx_mapping = {}
+
+ # Used to index into tensors
+ indices = []
+
+ # slots to keep after filtering
+ slot_filtering_indices = torch.zeros(
+ self.slots.shape[0], dtype=torch.bool, device=device
+ )
+
+ # Create on CPU to only move to GPU once instead of at every copy
+ slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+ max_input_length = 0
+ max_current_length = 0
+
+ requests = []
+ block_tables = []
+ all_input_ids = []
+ input_ids = []
+
+ prompt_lengths = []
+ input_lengths = []
+ cache_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ cu_slots = [0]
+
+ prefilling_mask = []
+ prefill_logprob_tokens = []
+
+ stopping_criterias = []
+ top_n_tokens = []
+ adapter_set = set()
+
+ num_blocks = 0
+ max_blocks = 0
+ max_slots = 0
+ cumulative_slot_tokens = 0
+
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ indices.append(idx)
+ requests_idx_mapping[request_id] = i
+
+ requests.append(self.requests[idx])
+
+ # Prefilling
+ request_prefilling = self.prefilling_mask[idx]
+ prefilling_mask.append(request_prefilling)
+
+ # Get length
+ request_input_length = self.input_lengths[idx]
+ request_cache_length = self.cache_lengths[idx]
+ max_input_length = max(max_input_length, request_input_length)
+ max_current_length = max(
+ max_current_length, request_cache_length + request_input_length
+ )
+
+ all_input_ids.append(self.all_input_ids[idx])
+
+ prompt_lengths.append(self.prompt_lengths[idx])
+ input_lengths.append(request_input_length)
+ cache_lengths.append(request_cache_length)
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+
+ top_n_tokens.append(self.top_n_tokens[idx])
+ prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
+
+ ADAPTER_TO_INDEX = get_adapter_to_index()
+ adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
+ adapter_set.add(adapter_index)
+
+ request_block_table = self.block_tables[idx]
+ num_blocks += len(request_block_table)
+ block_tables.append(request_block_table)
+
+ start_slot = self.cu_slots[idx]
+ end_slot = self.cu_slots[idx + 1]
+ slot_length = end_slot - start_slot
+
+ # Set slice
+ slot_filtering_indices[start_slot:end_slot] = True
+
+ cu_slots.append(cumulative_slot_tokens + slot_length)
+
+ # Input ids if the request was part of a prefilling batch
+ # If the batch was decoding we can index into the tensor directly later
+ if self.prefilling:
+ input_ids.append(self.input_ids[idx])
+ else:
+ # Copy to tensor (CPU)
+ slot_indices[i] = cumulative_slot_tokens + request_cache_length
+
+ cumulative_slot_tokens += slot_length
+ max_blocks = max(max_blocks, len(request_block_table))
+ max_slots = max(max_slots, slot_length)
+
+ all_input_ids_tensor = self.all_input_ids_tensor[indices]
+ block_tables_tensor = self.block_tables_tensor[indices]
+ next_token_chooser = self.next_token_chooser.filter(indices)
+ top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+ speculative_ids = (
+ self.speculative_ids[indices] if self.speculative_ids is not None else None
+ )
+ prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
+
+ cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
+
+ slots = self.slots[slot_filtering_indices]
+
+ if self.prefilling:
+ # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+ position_ids = None
+ slot_indices = None
+ cache_lengths_tensor = None
+ input_lengths_tensor = None
+ adapter_meta = None
+ else:
+ # Index into tensors
+ input_ids = self.input_ids[indices]
+ position_ids = self.position_ids[indices]
+ adapter_indices = self.adapter_meta.adapter_indices[indices]
+ input_lengths_tensor = self.input_lengths_tensor[indices]
+ cache_lengths_tensor = self.cache_lengths_tensor[indices]
+
+ # Move to GPU now that we have the whole tensor
+ slot_indices = slot_indices.to(device)
+
+ adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+ adapter_segments = torch.tensor(
+ adapter_segments, dtype=torch.int32, device=device
+ )
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_segment_indices,
+ )
+
+ return type(self)(
+ batch_id=self.batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ prefill_cache_indices=None,
+ slot_indices=slot_indices,
+ block_tables=block_tables,
+ block_tables_tensor=block_tables_tensor,
+ slots=slots,
+ cu_slots=cu_slots,
+ max_input_length=max_input_length,
+ max_current_length=max_current_length,
+ prefilling=self.prefilling,
+ prefilling_mask=prefilling_mask,
+ prefill_head_indices=None,
+ prefill_next_token_indices=None,
+ prefill_cu_outlens=None,
+ prefill_logprob_tokens=prefill_logprob_tokens,
+ prompt_lengths=prompt_lengths,
+ prompt_lengths_tensor=prompt_lengths_tensor,
+ input_lengths=input_lengths,
+ input_lengths_tensor=input_lengths_tensor,
+ cache_lengths=cache_lengths,
+ cache_lengths_tensor=cache_lengths_tensor,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ all_input_ids=all_input_ids,
+ all_input_ids_tensor=all_input_ids_tensor,
+ next_token_chooser=next_token_chooser,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ num_blocks=num_blocks,
+ max_blocks=max_blocks,
+ speculative_ids=speculative_ids,
+ adapter_meta=adapter_meta,
+ hpu_attn_meta=None,
+ )
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+
+ prefilling = False
+ num_blocks = 0
+ total_batch_size = 0
+ total_slots = 0
+ max_blocks = 0
+ max_length = 0
+ max_input_length = 0
+ max_current_length = 0
+ for b in batches:
+ total_batch_size += len(b)
+ max_blocks = max(max_blocks, b.max_blocks)
+ total_slots += len(b.slots)
+ num_blocks += b.num_blocks
+ speculative_length = (
+ b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
+ )
+ max_input_length = max(max_input_length, b.max_input_length)
+ max_current_length = max(max_current_length, b.max_current_length)
+ max_length = max(
+ max_length,
+ max(
+ prompt_length
+ + stopping_criteria.max_new_tokens
+ + speculative_length
+ for prompt_length, stopping_criteria in zip(
+ b.prompt_lengths, b.stopping_criterias
+ )
+ ),
+ )
+ prefilling = prefilling or b.prefilling
+
+ slots = batches[0].slots.new_empty(total_slots)
+ cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
+ if prefilling:
+ input_ids = []
+ # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+ position_ids = None
+ slot_indices = None
+ cache_lengths_tensor = None
+ input_lengths_tensor = None
+ adapter_meta = None
+ adapter_segment_builder = None
+ else:
+ input_ids = batches[0].input_ids.new_empty(total_batch_size)
+ if (
+ batches[0].position_ids is not None
+ and batches[0].position_ids.dim() == 2
+ ):
+ # Qwen2_vl case:
+ position_ids = batches[0].position_ids.new_empty(
+ (total_batch_size, batches[0].position_ids.shape[-1])
+ )
+ else:
+ position_ids = batches[0].position_ids.new_empty(total_batch_size)
+ slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+ input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
+ total_batch_size
+ )
+ cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
+ total_batch_size
+ )
+ total_indices_size = sum(
+ b.adapter_meta.adapter_indices.shape[0] for b in batches
+ )
+ adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+ total_indices_size
+ )
+ adapter_segment_builder = SegmentConcatBuilder()
+ adapter_set = set()
+
+ prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
+ total_batch_size
+ )
+ block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
+ (total_batch_size, max_blocks)
+ )
+ all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
+ (total_batch_size, max_length)
+ )
+ top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+
+ block_tables = []
+ cache_lengths = []
+ all_input_ids = []
+
+ prompt_lengths = []
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+
+ prefill_logprob_tokens = []
+
+ next_token_chooser_parameters = []
+ fsm_grammar_states = []
+ stopping_criterias = []
+ top_n_tokens = []
+ prefilling_mask = []
+
+ # Cumulative length
+ cumulative_batch_size = 0
+ cumulative_slots = 0
+ cumulative_adapter_indices_size = 0
+
+ for i, batch in enumerate(batches):
+ requests.extend(batch.requests)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + cumulative_batch_size
+
+ start_index = cumulative_batch_size
+ end_index = cumulative_batch_size + len(batch)
+
+ # Copy tensors (GPU)
+ top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+ all_input_ids_tensor[
+ start_index:end_index, : batch.all_input_ids_tensor.shape[1]
+ ] = batch.all_input_ids_tensor[:, :max_length]
+
+ block_tables_tensor[
+ start_index:end_index, : batch.block_tables_tensor.shape[1]
+ ] = batch.block_tables_tensor[:, :max_blocks]
+ prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
+
+ slots_start_index = cumulative_slots
+ slots_end_index = cumulative_slots + len(batch.slots)
+ slots[slots_start_index:slots_end_index] = batch.slots
+ cu_slots[start_index + 1 : end_index + 1] = (
+ batch.cu_slots[1:] + cumulative_slots
+ )
+
+ if not prefilling:
+ input_ids[start_index:end_index] = batch.input_ids
+ position_ids[start_index:end_index] = batch.position_ids
+ slot_indices[start_index:end_index] = (
+ batch.slot_indices + cumulative_slots
+ )
+ input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+ cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
+
+ # Copy over adapter indices
+ adapter_start_index = cumulative_adapter_indices_size
+ adapter_end_index = (
+ cumulative_adapter_indices_size
+ + batch.adapter_meta.adapter_indices.shape[0]
+ )
+ adapter_indices[adapter_start_index:adapter_end_index] = (
+ batch.adapter_meta.adapter_indices
+ )
+ cumulative_adapter_indices_size = adapter_end_index
+ adapter_set.update(batch.adapter_meta.adapter_set)
+ adapter_segment_builder.concat(
+ batch.adapter_meta.adapter_segments,
+ batch.adapter_meta.segment_indices,
+ )
+ else:
+ if isinstance(batch.input_ids, torch.Tensor):
+ batch.input_ids = batch.input_ids.view(-1, 1).tolist()
+ input_ids.extend(batch.input_ids)
+
+ prefilling_mask.extend(batch.prefilling_mask)
+ block_tables.extend(batch.block_tables)
+ cache_lengths.extend(batch.cache_lengths)
+ all_input_ids.extend(batch.all_input_ids)
+
+ prompt_lengths.extend(batch.prompt_lengths)
+ input_lengths.extend(batch.input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+
+ prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
+
+ next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+ fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
+ stopping_criterias.extend(batch.stopping_criterias)
+
+ top_n_tokens.extend(batch.top_n_tokens)
+
+ # Update
+ cumulative_slots += len(batch.slots)
+ cumulative_batch_size += len(batch)
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ next_token_chooser_parameters,
+ dtype=batches[0].next_token_chooser.dtype,
+ device=batches[0].next_token_chooser.device,
+ tokenizer=batches[0].next_token_chooser.tokenizer,
+ fsm_grammar_states=fsm_grammar_states,
+ )
+
+ # We skip computing the speculative_ids when the batch size is too large, so
+ # we must check that all batches have them, otherwise they must be discarded
+ if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+ speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+ else:
+ speculative_ids = None
+
+ if adapter_segment_builder is not None:
+ adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_segment_indices,
+ )
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ prefill_cache_indices=None,
+ slot_indices=slot_indices,
+ block_tables=block_tables,
+ block_tables_tensor=block_tables_tensor,
+ cache_lengths=cache_lengths,
+ cache_lengths_tensor=cache_lengths_tensor,
+ slots=slots,
+ cu_slots=cu_slots,
+ max_input_length=max_input_length,
+ max_current_length=max_current_length,
+ prefilling=prefilling,
+ prefilling_mask=prefilling_mask,
+ prefill_head_indices=None,
+ prefill_next_token_indices=None,
+ prefill_cu_outlens=None,
+ prefill_logprob_tokens=prefill_logprob_tokens,
+ prompt_lengths=prompt_lengths,
+ prompt_lengths_tensor=prompt_lengths_tensor,
+ input_lengths=input_lengths,
+ input_lengths_tensor=input_lengths_tensor,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ all_input_ids=all_input_ids,
+ all_input_ids_tensor=all_input_ids_tensor,
+ next_token_chooser=next_token_chooser,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ num_blocks=num_blocks,
+ max_blocks=max_blocks,
+ speculative_ids=speculative_ids,
+ adapter_meta=adapter_meta,
+ hpu_attn_meta=None,
+ )
+
+ def prepare_for_decode(self, dtype, use_contiguous_pa):
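+ # Keep, for each request, only the block-table entries that hold cached tokens:
+ # cache_length // BLOCK_SIZE full blocks plus one (possibly partial) block.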
+ block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
+ block_tables = []
+ for i, bt in enumerate(self.block_tables):
+ block_tables.append(bt[0 : block_num[i]])
+
+ self.hpu_attn_meta = prepare_for_decode(
+ dtype,
+ use_contiguous_pa,
+ self.block_tables_tensor.device,
+ self.slots[self.slot_indices],
+ block_tables,
+ self.input_ids.size(0),
+ )
+
+ def prepare_for_prefill(self):
+ # Prepare values if we need to continue prefilling
+ # Speculation must be ignored while we prefill, even with chunking:
+ # it simplifies everything
+ assert self.speculative_ids is None
+
+ device = self.block_tables_tensor.device
+
+ # HPU does not support varlen prefill, so SDPA is used instead and input_ids/position_ids must be padded.
+ # Padding is applied on the left so that it composes with sliding-window attention.
+ # prefill_cache_indices marks the valid KV slots, and prefill_next_token_indices is updated to point at
+ # the correct logit positions.
+ input_ids_padded_length = []
+ # extra padding may be needed to match the warmup sequence length
+ extra_pad = 0
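+ # Left-pad every request to max_input_length (+ extra_pad), recording each pad length so that
+ # logits and KV indices can be shifted later; e.g. with max_input_length=5, [7, 8, 9] becomes [0, 0, 7, 8, 9].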
+ if isinstance(self.input_ids, list) and len(self) > 1:
+ input_ids_padded_length = []
+ input_ids = []
+ for input_id in self.input_ids:
+ padded = self.max_input_length - len(input_id) + extra_pad
+ if padded > 0:
+ input_id = [0] * padded + input_id
+ input_ids.append(input_id)
+ input_ids_padded_length.append(padded)
+ input_ids = np.concatenate(input_ids, dtype=np.int64)
+ self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+ elif isinstance(self.input_ids, list):
+ input_ids = self.input_ids[0]
+ input_ids_padded_length.append(extra_pad)
+ input_ids = [0] * extra_pad + input_ids
+ self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+ else:
+ self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
+ input_ids_padded_length.append(extra_pad)
+
+ self.input_lengths_tensor = torch.tensor(
+ self.input_lengths, dtype=torch.int32, device=device
+ )
+ cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
+ torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
+ self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
+ self.cache_lengths_tensor = torch.tensor(
+ self.cache_lengths, dtype=torch.int32, device=device
+ )
+
+ sliding_window = get_sliding_windows()
+ position_ids = []
+ slot_indices = []
+ prefill_cache_indices = []
+ all_prefill_logprobs = True
+ no_prefill_logprobs = True
+ prefill_cu_outlens = [0]
+
+ # Cumulative length
+ cumulative_length = 0
+ cumulative_slot_tokens = 0
+ prefill_out_cumulative_length = 0
+
+ adapter_indices_list = []
+ adapter_set = set()
+
+ for i, (
+ r,
+ cache_length,
+ input_length,
+ prompt_length,
+ request_prefilling,
+ blocks,
+ ) in enumerate(
+ zip(
+ self.requests,
+ self.cache_lengths,
+ self.input_lengths,
+ self.prompt_lengths,
+ self.prefilling_mask,
+ self.block_tables,
+ )
+ ):
+ next_chunk_length = input_length
+
+ # Position ids
+ request_position_ids = torch.arange(
+ cache_length, cache_length + input_length, dtype=torch.int32
+ )
+ request_position_ids = F.pad(
+ request_position_ids, (input_ids_padded_length[i], 0), value=1
+ )
+ position_ids.append(request_position_ids)
+
+ if not r.slots:
+ request_slots = [
+ s
+ for b in blocks
+ for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+ ]
+ else:
+ request_slots = r.slots
+
+ request_slot_indices = torch.arange(
+ cache_length + cumulative_slot_tokens,
+ cache_length + cumulative_slot_tokens + input_length,
+ dtype=torch.int64,
+ )
+
+ slot_indices.append(request_slot_indices)
+
+ # Update
+ cumulative_slot_tokens += len(request_slots)
+
+ # Create tensor to slice into the kv tensor in prefill
+ # hpu need request_prefill_cache_indices to skip padding in kv cache
+ sliding_window = get_sliding_windows()
+ if sliding_window is None:
+ sliding_window = input_length
+ cumulative_length += input_ids_padded_length[i]
+ if sliding_window is not None:
+ request_prefill_cache_indices = torch.arange(
+ cumulative_length + max(0, input_length - sliding_window),
+ cumulative_length + input_length,
+ dtype=torch.int64,
+ )
+
+ # Prefill logprobs is ignored if the request is done prefilling
+ prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+ all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
+ no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
+
+ if prefill_logprobs:
+ prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+ prefill_out_cumulative_length += input_length
+ else:
+ prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+ prefill_out_cumulative_length += 1
+
+ prefill_cache_indices.append(request_prefill_cache_indices)
+
+ ADAPTER_TO_INDEX = get_adapter_to_index()
+ if ADAPTER_TO_INDEX:
+ adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+ adapter_indices_list.append(
+ torch.full((next_chunk_length,), adapter_index)
+ )
+ adapter_set.add(adapter_index)
+
+ # Update
+ cumulative_length += next_chunk_length
+
+ if not all_prefill_logprobs and not no_prefill_logprobs:
+ prefill_head_indices = []
+ prefill_next_token_indices = []
+
+ # Cumulative length
+ cumulative_length = 0
+ prefill_out_cumulative_length = 0
+
+ for i, (
+ r,
+ input_length,
+ request_prefilling,
+ ) in enumerate(
+ zip(
+ self.requests,
+ self.input_lengths,
+ self.prefilling_mask,
+ )
+ ):
+ # Prefill logprobs is ignored if the request is done prefilling
+ prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+ if prefill_logprobs:
+ prefill_head_indices.append(
+ torch.arange(
+ cumulative_length,
+ cumulative_length + input_length,
+ dtype=torch.int64,
+ )
+ )
+ prefill_next_token_indices.append(
+ prefill_out_cumulative_length + input_length - 1
+ )
+ prefill_out_cumulative_length += input_length
+ else:
+ prefill_head_indices.append(
+ torch.tensor(
+ [cumulative_length + input_length - 1],
+ dtype=torch.int64,
+ )
+ )
+ prefill_next_token_indices.append(prefill_out_cumulative_length)
+ prefill_out_cumulative_length += 1
+
+ # Update
+ cumulative_length += input_length
+
+ if len(self) > 1:
+ if position_ids:
+ position_ids = torch.cat(position_ids)
+ if slot_indices:
+ slot_indices = torch.cat(slot_indices)
+ prefill_cache_indices = torch.cat(prefill_cache_indices)
+ else:
+ if position_ids:
+ position_ids = position_ids[0]
+ if slot_indices:
+ slot_indices = slot_indices[0]
+ prefill_cache_indices = prefill_cache_indices[0]
+
+ self.position_ids = position_ids.to(device)
+ self.slot_indices = slot_indices.to(device)
+
+ self.prefill_cu_outlens = prefill_cu_outlens
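+ # Boolean mask over the (left-padded) inputs: True marks real tokens whose KV entries must be written,
+ # False marks padding positions that should be skipped.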
+ self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool)
+ self.prefill_cache_indices[prefill_cache_indices.to(device)] = True
+
+ if all_prefill_logprobs:
+ prefill_head_indices = None
+ prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
+ elif no_prefill_logprobs:
+ prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
+ prefill_next_token_indices = None
+ else:
+ prefill_head_indices = torch.cat(prefill_head_indices).to(device)
+ prefill_next_token_indices = torch.tensor(
+ prefill_next_token_indices, dtype=torch.int64, device=device
+ )
+
+ self.prefill_head_indices = prefill_head_indices
+ self.prefill_next_token_indices = prefill_next_token_indices
+ input_ids_padded_length_tensor = torch.cumsum(
+ torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),
+ dim=-1,
+ )
+ if self.prefill_head_indices is not None:
+ self.prefill_head_indices = (
+ self.prefill_head_indices + input_ids_padded_length_tensor
+ )
+
+ if self.prefill_next_token_indices is not None:
+ self.prefill_next_token_indices = (
+ self.prefill_next_token_indices + input_ids_padded_length_tensor
+ )
+
+ if adapter_set:
+ adapter_indices = torch.cat(adapter_indices_list).to(
+ dtype=torch.int64, device=device
+ )
+ adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+ else:
+ adapter_indices = torch.zeros_like(self.input_ids)
+ adapter_segments = [0, len(adapter_indices)]
+ adapter_segment_indices = [len(adapter_indices) - 1]
+
+ adapter_segments = torch.tensor(
+ adapter_segments, dtype=torch.int32, device=device
+ )
+ self.adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_segment_indices,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+ADAPTER_LAYERS = [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
+
+
+class FlashCausalLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ model_class,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ lora_adapter_ids: Optional[list] = [],
+ tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+ config_class=AutoConfig,
+ default_dtype=torch.float16,
+ aliases=None,
+ # Used for Santacoder override of config
+ num_kv_heads: Optional[int] = None,
+ # Deepseek V2 uses different QK and V dims.
+ head_size: Optional[int] = None,
+ skip_special_tokens: bool = True,
+ kv_cache_dtype: Optional[torch.dtype] = None,
+ support_chunking: bool = True,
+ ):
+ self.quantize = quantize
+ self.process_group, rank, world_size = initialize_torch_distributed()
+
+ device = torch.device("hpu")
+ dtype = torch.bfloat16 if dtype is None else dtype
+
+ tokenizer = tokenizer_class.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ try:
+ generation_config = GenerationConfig.from_pretrained(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ if isinstance(generation_config.eos_token_id, (list, set)):
+ # TODO Huge hack
+ tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+ except Exception:
+ pass
+
+ config = config_class.from_pretrained(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ config.quantize = quantize
+ config.speculator = speculator
+
+ torch.distributed.barrier(group=self.process_group)
+
+ weights_loader = get_loader(quantize, model_id, revision)
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device,
+ dtype,
+ process_group=self.process_group,
+ aliases=aliases,
+ weights_loader=weights_loader,
+ )
+
+ prefix = None
+ model = model_class(prefix, config, weights)
+ torch.distributed.barrier(group=self.process_group)
+
+ # VLM models define the config we care about in their text_config
+ text_config = getattr(config, "text_config", None)
+ if text_config is not None:
+ config = text_config
+
+ if getattr(config, "sliding_window", None) is not None:
+ set_sliding_window(config.sliding_window)
+ else:
+ config.sliding_window = None
+
+ self.num_layers = config.num_hidden_layers
+ self.num_heads = config.num_attention_heads // self.process_group.size()
+ self.config = config
+ # Validation is done in the model itself
+ if num_kv_heads is None:
+ num_kv_heads = getattr(config, "num_key_value_heads", None)
+ # GPT-2 workaround
+ if num_kv_heads is None:
+ num_kv_heads = getattr(config, "n_head", None)
+ if num_kv_heads is None:
+ raise ValueError("Cannot get the number of key/value heads")
+ self.num_kv_heads = (
+ num_kv_heads // self.process_group.size()
+ if num_kv_heads > 1
+ else num_kv_heads
+ )
+ assert self.num_kv_heads > 0
+
+ if head_size is None:
+ # Some models use GQA and different sizes for o_proj
+ # and q_proj; an explicit head_dim accounts for that.
+ if hasattr(config, "head_dim"):
+ self.head_size = config.head_dim
+ else:
+ self.head_size = config.hidden_size // config.num_attention_heads
+ else:
+ self.head_size = head_size
+
+ self.cuda_graphs = {}
+ self.kv_cache = []
+ self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+
+ if htorch.utils.internal.is_lazy():
+ htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=False)
+ environment.set_model_config(self.config)
+ self.use_contiguous_pa = (
+ os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
+ )
+ super().__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=False,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ sliding_window=config.sliding_window,
+ support_chunking=support_chunking,
+ )
+
+ @property
+ def batch_type(self) -> Type[FlashCausalLMBatch]:
+ return FlashCausalLMBatch
+
+ def max_past(self) -> Optional[int]:
+ return getattr(self.model, "max_past", None)
+
+ def init_kv_cache(
+ self,
+ num_blocks: int,
+ num_layers: int,
+ num_heads: int,
+ head_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ):
+ self.kv_cache = []
+ empty_cache()
+ self.kv_cache = [
+ KVCache(
+ num_blocks=num_blocks,
+ num_heads=num_heads,
+ head_size=head_size,
+ dtype=dtype,
+ device=device,
+ )
+ for _ in range(num_layers)
+ ]
+
+ def warmup(
+ self,
+ batch: FlashCausalLMBatch,
+ max_input_tokens: Optional[int],
+ max_total_tokens: Optional[int],
+ ):
+ # The warmup batch is the biggest batch we could ever receive
+ self.kv_cache = []
+ empty_cache()
+
+ # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+ # Calculate the number of blocks that can be allocated with the free memory
+ dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
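+ # Bytes needed per KV-cache block: elements per block for one layer, doubled for key and value,
+ # summed over all layers.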
+ cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+ total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+
+ try:
+ self.init_kv_cache(
+ batch.num_blocks,
+ self.num_layers,
+ self.num_kv_heads,
+ self.head_size,
+ self.kv_cache_dtype,
+ self.device,
+ )
+
+ batch_num_blocks = batch.num_blocks
+
+ num_tokens = batch.to_pb().current_tokens
+ synchronize(self.device)
+ free_memory = get_free_memory(
+ self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
+ )
+ real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+ log_master(
+ logger.debug,
+ f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
+ )
+
+ _, _batch, _ = self.generate_token([batch])
+ except Exception:
+ raise RuntimeError(
+ f"Not enough memory to handle {num_tokens} prefill tokens. "
+ f"You need to decrease `--max-batch-prefill-tokens`"
+ )
+
+ synchronize(self.device)
+ free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
+ kv_memory = free_memory
+ num_blocks = (
+ # Leave 5% for some wiggle room
+ int(kv_memory // total_cache_size)
+ # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ + batch_num_blocks
+ )
+
+ log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
+ if max_total_tokens is None:
+ max_total_tokens = sum(batch.cache_lengths)
+
+ if max_input_tokens is None:
+ max_input_tokens = max_total_tokens - 1
+
+ del _batch, batch
+ self.kv_cache = []
+ empty_cache()
+
+ self.init_kv_cache(
+ num_blocks,
+ self.num_layers,
+ self.num_kv_heads,
+ self.head_size,
+ self.kv_cache_dtype,
+ self.device,
+ )
+ return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
+
+ def warmup_prefill(self, prompt_len: int, bs: int):
+ input_ids = torch.zeros(
+ prompt_len, dtype=torch.int64, device=self.device
+ ).repeat(bs)
+ position_ids = torch.arange(
+ prompt_len, dtype=torch.int32, device=self.device
+ ).repeat(bs)
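+ # Upper bound on block-table entries: each sequence needs prompt_len // BLOCK_SIZE blocks plus one partial block.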
+ max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+ block_tables = torch.arange(
+ max_bt, dtype=torch.int32, device=self.device
+ ).reshape(bs, -1)
+ slot_acc = []
+ for i in range(bs):
+ slots = []
+ for b in block_tables[i]:
+ slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
+ slot_acc.extend(slots[:prompt_len])
+ slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+
+ input_lengths = (
+ torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+ )
+ cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+ cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+ torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+ seqlen = Seqlen(
+ input_lengths=input_lengths,
+ cache_lengths=cache_lengths_tensor,
+ cu_seqlen_q=cu_seqlen_prefill,
+ )
+ lm_head_indices = input_lengths - 1
+
+ # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+ self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=self.kv_cache,
+ slots=slots,
+ seqlen=trim_seqlen_metadata(seqlen),
+ lm_head_indices=lm_head_indices,
+ adapter_data=None,
+ hpu_attention_meta=None,
+ )
+
+ def warmup_decode(self, bs: int, block_num: int):
+ input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+ position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
+ block_tables = torch.arange(
+ start=1, end=block_num + 1, dtype=torch.int32, device=self.device
+ ).reshape(bs, -1)
+ slots = []
+ past_len = (
+ len(block_tables[0]) * BLOCK_SIZE - 1
+ ) # for decode, we only need to pass the past token
+ # use the last slot of each request's last block so the warmup covers the requested number of blocks
+ for i in range(bs):
+ slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
+ slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+ input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
+ cache_lengths_tensor = (
+ torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
+ )
+ cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+ torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+ seqlen = Seqlen(
+ input_lengths=input_lengths,
+ cache_lengths=cache_lengths_tensor,
+ cu_seqlen_q=cu_seqlen_prefill,
+ )
+ block_num = cache_lengths_tensor // BLOCK_SIZE + 1
+ block_tables_valid = []
+ for i, bt in enumerate(block_tables.tolist()):
+ block_tables_valid.append(bt[0 : block_num[i]])
+
+ hpu_attention_meta = prepare_for_decode(
+ self.dtype,
+ self.use_contiguous_pa,
+ self.device,
+ slots,
+ block_tables_valid,
+ bs,
+ )
+
+ # Decode pass: `cu_seqlen_prefill` is None and the paged-attention metadata is supplied via `hpu_attention_meta`.
+ self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=None,
+ kv_cache=self.kv_cache,
+ slots=slots,
+ seqlen=trim_seqlen_metadata(seqlen),
+ lm_head_indices=None,
+ adapter_data=None,
+ hpu_attention_meta=hpu_attention_meta,
+ )
+
+ def forward(
+ self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # Model Forward
+ if batch.speculative_ids is not None:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ speculative_ids = batch.speculative_ids
+
+ B, speculative_length = speculative_ids.shape
+ new_length = speculative_length + 1
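+ # Each request expands into new_length rows (its current token plus the speculative candidates),
+ # so all per-token tensors below are flattened to B * new_length positions.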
+ new_input_ids = torch.cat(
+ [input_ids.unsqueeze(-1), speculative_ids], dim=1
+ ).reshape(-1)
+ arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+ arange_int = arange.to(dtype=torch.int32)
+ new_position_ids = (
+ position_ids.unsqueeze(-1).expand(B, new_length) + arange
+ ).view(-1)
+
+ # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+ # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+ # allocated
+ slot_indices = (
+ batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+ slots = batch.slots[slot_indices]
+
+ input_lengths = (
+ input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+ cache_lengths_tensor = (
+ batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+ ).reshape(-1)
+
+ # Also copy the block tables for all members
+ block_tables = (
+ block_tables.unsqueeze(1)
+ .expand(B, new_length, -1)
+ .reshape(B * new_length, -1)
+ .contiguous()
+ )
+ max_s = max_s + speculative_length
+
+ input_ids = new_input_ids
+ position_ids = new_position_ids
+ else:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ cache_lengths_tensor = batch.cache_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ if cu_seqlen_prefill is None and self.max_past() is not None:
+ # In decode, not prefill, we're actually overwriting the KV-cache
+ # in a circular buffer mode.
+ # This makes sure the max_s for the decode pass is correct.
+ max_s = min(self.max_past(), max_s)
+
+ seqlen = Seqlen(
+ input_lengths=input_lengths,
+ cache_lengths=cache_lengths_tensor,
+ cu_seqlen_q=cu_seqlen_prefill,
+ )
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = False
+ if batch.prefill_cache_indices is not None:
+ slots_pad = torch.zeros_like(input_ids)
+ slots_pad[batch.prefill_cache_indices] = slots
+ slots = slots_pad
+ logits, speculative_logits = self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=trim_seqlen_metadata(seqlen),
+ lm_head_indices=lm_head_indices,
+ # TODO: adapters are not supported yet; add support in the future
+ adapter_data=None,
+ hpu_attention_meta=batch.hpu_attn_meta,
+ **kwargs,
+ )
+ return logits, speculative_logits
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batches: List[FlashCausalLMBatch]
+ ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+ if len(batches) > 1:
+ batch = self.batch_type.concatenate(batches)
+ else:
+ batch = batches[0]
+ start = time.time_ns()
+ prefill = batch.prefilling
+ if prefill:
+ batch.prepare_for_prefill()
+ else:
+ batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)
+ prefill_logprobs = batch.prefill_next_token_indices is not None
+ # Update adapter indices for speculative tokens (if present)
+ adapter_meta = batch.adapter_meta
+ if batch.speculative_ids is not None:
+ B, speculative_length = batch.speculative_ids.shape
+ new_length = speculative_length + 1
+ adapter_indices = (
+ adapter_meta.adapter_indices.unsqueeze(-1)
+ .expand(B, new_length)
+ .reshape(-1)
+ )
+ adapter_segments = adapter_meta.adapter_segments * new_length
+ adapter_meta = AdapterBatchMetadata(
+ adapter_indices=adapter_indices,
+ adapter_set=adapter_meta.adapter_set,
+ adapter_segments=adapter_segments,
+ segment_indices=adapter_meta.segment_indices,
+ )
+
+ # Assign pointers to adapter weights
+ # TODO(travis): don't update this if indices haven't changed
+ adapter_data = AdapterBatchData.from_meta(
+ adapter_meta,
+ self.layer_to_adapter_weights,
+ prefill,
+ batch.prefill_head_indices,
+ )
+
+ out, speculative_logits = self.forward(batch, adapter_data)
+
+ if prefill:
+ next_token_logits = (
+ out[batch.prefill_next_token_indices] if prefill_logprobs else out
+ )
+ if speculative_logits is not None:
+ speculative_logits = (
+ speculative_logits[batch.prefill_next_token_indices]
+ if prefill_logprobs
+ else speculative_logits
+ )
+ if len(batch) > 1 and prefill_logprobs:
+ # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
+ # When batch == 1, we will just use the batch.input_ids values directly
+ prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
+ else:
+ prefill_logprobs = None
+ next_token_logits = out
+
+ finished_prefilling = True
+ next_chunk_lengths = []
+ current_prefilling_mask = batch.prefilling_mask
+ if prefill:
+ finished_prefilling = True
+ next_prefilling_mask = [False] * len(batch)
+
+ batch.prefilling = not finished_prefilling
+ batch.prefilling_mask = next_prefilling_mask
+
+ speculate = get_speculate()
+ (
+ next_input_ids,
+ next_token_logprobs,
+ logprobs,
+ accepted_ids,
+ speculative_ids,
+ ) = batch.next_token_chooser(
+ batch.all_input_ids_tensor[:, : batch.max_current_length],
+ next_token_logits,
+ speculate,
+ batch.speculative_ids,
+ speculative_logits,
+ )
+
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
+ )
+
+ # Since we are done prefilling, all the tensors that concatenated values for all the requests
+ # instantly become of shape [BATCH_SIZE]
+ if prefill and finished_prefilling:
+ indices = batch.cu_seqlen_prefill[1:] - 1
+ # inputs were left-padded, so first select the valid (non-padding) positions
+ if batch.prefill_cache_indices is not None:
+ batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
+ indices
+ ]
+ else:
+ batch.position_ids = batch.position_ids[indices]
+
+ batch.slot_indices = batch.slot_indices[indices]
+ batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
+ indices
+ ]
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.prompt_lengths,
+ batch.cache_lengths,
+ batch.input_lengths,
+ batch.all_input_ids,
+ accepted_ids,
+ current_prefilling_mask,
+ batch.prefilling_mask,
+ )
+
+ # We do two for loops: the first one can run completely asynchronously from the HPU, while the second
+ # one requires an HPU <-> CPU sync first.
+ # It is faster if we delay this sync for as long as possible.
+
+ # For each member of the batch
+ # Cumulative length
+ cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+ torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
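+ # cu_accepted_ids[i]:cu_accepted_ids[i + 1] delimits the accepted tokens of request i in the flattened next_input_ids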
+ cumulative_length = 0
+ for i, (
+ request,
+ prompt_length,
+ cache_length,
+ input_length,
+ all_input_ids,
+ n_accepted_ids,
+ request_was_prefilling,
+ request_is_prefilling,
+ ) in enumerate(iterator):
+ # Used to gather prefill logprobs
+ # Copy batch.all_input_ids_tensor to prefill_token_indices
+ if request.prefill_logprobs and request_was_prefilling:
+ # Indexing metadata
+ out_start_index = batch.prefill_cu_outlens[i]
+ out_end_index = batch.prefill_cu_outlens[i + 1]
+
+ # Logprobs generated by the model are for the next token
+ # So we need to translate the id tensor by 1
+ ids = batch.all_input_ids_tensor[
+ i, cache_length + 1 : cache_length + input_length + 1
+ ]
+ if len(batch) > 1:
+ prefill_tokens_indices[out_start_index:out_end_index] = ids
+ else:
+ # Set prefill_tokens_indices to the correct slice
+ prefill_tokens_indices = ids
+
+ # If the device does not support triton, we copy one by one
+ if not request_is_prefilling:
+ # Only save tokens if we are done prefilling for this request
+ batch.all_input_ids_tensor[
+ i,
+ batch.cache_lengths_tensor[i]
+ + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ + batch.input_lengths[i]
+ + accepted_ids[i],
+ ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+ cumulative_length += input_length
+
+ # Update values
+ # These values can be updated without a HPU -> CPU sync
+ if not prefill or (prefill and finished_prefilling):
+ batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+ batch.speculative_ids = speculative_ids
+ if batch.position_ids.dim() == 2:
+ # Qwen2_vl case:
+ batch.position_ids += accepted_ids.unsqueeze(-1)
+ else:
+ batch.position_ids += accepted_ids
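+ # Decode bookkeeping: accepted tokens move into the cache, each request now feeds a single new token,
+ # and the slot pointers advance by the number of accepted tokens.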
+ batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
+ batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
+ batch.slot_indices += accepted_ids
+
+ if prefill and prefill_logprobs:
+ # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
+ torch.log_softmax(out, -1, out=out)
+ prefill_logprobs_tensor = out
+ prefill_logprobs = torch.gather(
+ prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
+ )
+ # HPU <-> CPU sync
+ prefill_logprobs = prefill_logprobs.view(-1).tolist()
+
+ # Does a HPU <-> CPU sync internally
+ if prefill and finished_prefilling:
+ # adjust segment lengths to account for all request lengths being 1 during decoding
+ adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+ batch.adapter_meta.adapter_segments = torch.tensor(
+ adapter_segments,
+ dtype=torch.int32,
+ device=batch.adapter_meta.adapter_segments.device,
+ )
+
+ # HPU <-> CPU sync
+ next_token_logprobs = next_token_logprobs.tolist()
+ next_token_ids = next_input_ids.tolist()
+ accepted_ids = accepted_ids.tolist()
+
+ # Update values if we need to continue prefilling
+ # This represents the `else` case of the `Update values` if above
+ # but since this requires the `next_token_ids` to be on CPU, it is better to do it here
+ if prefill and not finished_prefilling:
+ # Speculation must be ignored while we prefill, even with chunking:
+ # it simplifies everything
+ assert batch.speculative_ids is None
+
+ all_postfix_ids = []
+ for i, (
+ request_prefilling,
+ next_token_id,
+ all_input_ids,
+ cache_length,
+ input_length,
+ next_chunk_length,
+ ) in enumerate(
+ zip(
+ batch.prefilling_mask,
+ next_token_ids,
+ batch.all_input_ids,
+ batch.cache_lengths,
+ batch.input_lengths,
+ next_chunk_lengths,
+ )
+ ):
+ if request_prefilling:
+ next_cache_length = cache_length + input_length
+ # Get new prompt IDs to prefill
+ postfix_ids = all_input_ids[
+ next_cache_length : next_cache_length + next_chunk_length
+ ]
+ else:
+ # This request is done prefilling, the new id is the one selected by the sampling method
+ postfix_ids = [next_token_id]
+
+ all_postfix_ids.append(postfix_ids)
+
+ batch.input_ids = all_postfix_ids
+
+ start_decode = time.time_ns()
+
+ # Results
+ generations: List[Generation] = []
+ stopped = True
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.prompt_lengths,
+ batch.cache_lengths,
+ batch.input_lengths,
+ batch.prefix_offsets,
+ batch.read_offsets,
+ batch.stopping_criterias,
+ batch.all_input_ids,
+ batch.next_token_chooser.do_sample,
+ batch.next_token_chooser.seeds,
+ batch.top_n_tokens,
+ current_prefilling_mask,
+ batch.prefilling_mask,
+ accepted_ids,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
+ )
+
+ # Reset max_input_length
+ batch.max_input_length = 0
+ # For each member of the batch
+ index = 0
+ for i, (
+ request,
+ prompt_length,
+ cache_length,
+ input_length,
+ prefix_offset,
+ read_offset,
+ stopping_criteria,
+ all_input_ids,
+ do_sample,
+ seed,
+ top_n_tokens,
+ request_was_prefilling,
+ request_is_prefilling,
+ n_accepted_ids,
+ top_token_ids,
+ top_token_logprobs,
+ ) in enumerate(iterator):
+ # Compute logprobs first: even though we might skip the token,
+ # it can still be required to compute the logprobs.
+ # We key on request.id modulo world_size because it is robust to batch.filter,
+ # whereas the index in the batch is not, and we need this state to be stable.
+ if request.id % self.world_size == self.rank:
+ # Prefill
+ if request_was_prefilling and request.prefill_logprobs:
+ out_start_index = batch.prefill_cu_outlens[i]
+ out_end_index = batch.prefill_cu_outlens[i + 1]
+ if not request_is_prefilling:
+ # The request is done prefilling, meaning that we started generating new tokens
+ # The last logprob is a logprob for a generated token that was not part of the prompt
+ # We need to remove it
+ out_end_index -= 1
+
+ request_prefill_logprobs = prefill_logprobs[
+ out_start_index:out_end_index
+ ]
+ # Logprobs generated by the model are for the next token
+ # So we need to translate the id tensor by 1
+ prefill_token_ids = all_input_ids[
+ cache_length + 1 : cache_length + input_length + 1
+ ]
+
+ past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
+
+ if past_prefill_logprob_tokens is None:
+ # add nan for cached prompt tokens/first token
+ request_prefill_logprobs = [float("nan")] * (
+ cache_length + 1
+ ) + request_prefill_logprobs
+ prefill_token_ids = (
+ all_input_ids[: cache_length + 1] + prefill_token_ids
+ )
+
+ prefill_texts = self.tokenizer.batch_decode(
+ prefill_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+
+ prefill_logprob_tokens = Tokens(
+ prefill_token_ids,
+ request_prefill_logprobs,
+ prefill_texts,
+ is_special=[],
+ )
+ if past_prefill_logprob_tokens is not None:
+ prefill_logprob_tokens = (
+ past_prefill_logprob_tokens + prefill_logprob_tokens
+ )
+
+ batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
+ else:
+ batch.prefill_logprob_tokens[i] = None
+
+ # If the request is still prefilling, the tokens we decoded should be ignored
+ if request_is_prefilling:
+ # Make sure that we do not stop: even though this request did not create a token, it is still
+ # being processed
+ stopped = False
+ new_input_length = next_chunk_lengths[i]
+ new_cache_length = cache_length + input_length
+ else:
+ new_input_length = 1
+ new_cache_length = cache_length + input_length + n_accepted_ids - 1
+ # Append next token to all tokens
+ next_token_texts = []
+ left = 0
+
+ if n_accepted_ids > 1:
+ log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
+
+ current_stopped = False
+ for j in range(index, index + n_accepted_ids):
+ # Generated token
+ next_token_id = next_token_ids[j]
+ all_input_ids.append(next_token_id)
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids,
+ prefix_offset,
+ read_offset,
+ )
+ next_token_texts.append(next_token_text)
+
+ stop, reason = stopping_criteria(
+ next_token_id,
+ next_token_text,
+ )
+
+ if stop:
+ left = index + n_accepted_ids - j - 1
+ current_stopped = True
+ break
+ else:
+ current_stopped = False
+ stopped = stopped and current_stopped
+
+ _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+ _next_token_logprobs = next_token_logprobs[
+ index : index + n_accepted_ids - left
+ ]
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if request.id % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ output_text, _, _ = self.decode_token(
+ all_input_ids,
+ prefix_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens
+ - 1,
+ read_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens,
+ skip_special_tokens=True,
+ )
+ generated_text = GeneratedText(
+ output_text,
+ stopping_criteria.current_tokens,
+ reason,
+ seed if do_sample else None,
+ )
+ else:
+ generated_text = None
+
+ if top_n_tokens > 0:
+ all_top_tokens = []
+ for top_token_ids, top_token_logprobs in zip(
+ top_token_ids, top_token_logprobs
+ ):
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids
+ for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ all_top_tokens.append(top_tokens)
+ top_tokens = all_top_tokens
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ batch.prefill_logprob_tokens[i],
+ Tokens(
+ _next_token_ids,
+ _next_token_logprobs,
+ next_token_texts,
+ [nid in self.all_special_ids for nid in _next_token_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ # accept each new token for this specific request since we may
+ # have more than one new token per request with speculative decoding
+ for next_token_id in _next_token_ids:
+ batch.next_token_chooser = (
+ batch.next_token_chooser.advance_grammar_single(
+ i, next_token_id
+ )
+ )
+
+ # Update values
+ index += n_accepted_ids
+ batch.cache_lengths[i] = new_cache_length
+ batch.max_input_length = max(batch.max_input_length, new_input_length)
+ batch.input_lengths[i] = new_input_length
+ current_length = new_cache_length + new_input_length
+ batch.max_current_length = max(batch.max_current_length, current_length)
+
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.all_input_ids[i] = all_input_ids
+
+ if stopped:
+ # No need to return a batch if we know that all requests stopped
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, None, (forward_ns, decode_ns)
+
+ if prefill and finished_prefilling:
+ # We do not need prefill tensors anymore
+ batch.cu_seqlen_prefill = None
+ batch.prefill_cache_indices = None
+ batch.prefill_cu_outlens = None
+ batch.prefill_head_indices = None
+ batch.prefill_next_token_indices = None
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch, (forward_ns, decode_ns)
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
new file mode 100644
index 000000000..208ab3582
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -0,0 +1,489 @@
+import torch
+from PIL import Image
+from io import BytesIO
+
+from opentelemetry import trace
+from typing import Iterable, Optional, Tuple, List, Type, Dict
+
+from transformers import PreTrainedTokenizerBase
+from transformers.image_processing_utils import select_best_resolution
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import (
+ FlashCausalLMBatch,
+ FlashCausalLM,
+)
+from text_generation_server.models.globals import PREFIX_CACHING
+from loguru import logger
+from text_generation_server.utils.log import log_master
+from transformers import AutoProcessor
+from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+import habana_frameworks.torch as htorch
+
+tracer = trace.get_tracer(__name__)
+
+IDEFICS2_FAKE_TOKEN = ""
+IDEFICS2_IMAGE_TOKEN = ""
+
+IDEFICS3_IMAGE_TOKEN = ""
+IDEFICS3_FAKE_IMAGE_TOKEN = ""
+IDEFICS3_GLOBAL_IMG_TOKEN = ""
+
+
+# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
+def _prompt_split_image(
+ *,
+ image_seq_len: int,
+ image_rows: int,
+ image_cols: int,
+ fake_token_around_image: str,
+ image_token: str,
+ global_img_token: str,
+):
+ """Prompt with expanded image tokens for when the image is split into patches."""
+ text_split_images = ""
+ for n_h in range(image_rows):
+ for n_w in range(image_cols):
+ text_split_images += (
+ f"{fake_token_around_image}"
+ + f""
+ + f"{image_token}" * image_seq_len
+ )
+ text_split_images += "\n"
+
+ text_split_images += (
+ f"\n{fake_token_around_image}"
+ + f"{global_img_token}"
+ + f"{image_token}" * image_seq_len
+ + f"{fake_token_around_image}"
+ )
+ return text_split_images
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+ The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise ValueError("grid_pinpoints should be a list of tuples or lists")
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+ if config.model_type == "idefics2":
+ image_seq_len = 64
+ image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
+ if processor.image_processor.do_image_splitting:
+ image_str *= 5
+ return image_str
+ if config.model_type == "idefics3":
+ # TODO: implement this in a more general way
+ n_rows = image_input["rows"][0][image_id]
+ n_cols = image_input["cols"][0][image_id]
+ image_seq_len = int(
+ ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
+ / (config.scale_factor**2)
+ )
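+ # Tokens per image tile: (image_size / patch_size)^2 vision patches, reduced by scale_factor^2
+ # (the connector downsamples by scale_factor along each spatial dimension).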
+ image_str = _prompt_split_image(
+ image_seq_len=image_seq_len,
+ image_rows=n_rows,
+ image_cols=n_cols,
+ fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+ image_token=IDEFICS3_IMAGE_TOKEN,
+ global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+ )
+ return image_str
+ elif config.model_type == "llava_next":
+ height, width = image_input["image_sizes"][image_id]
+ num_features = get_number_of_features(height, width, config)
+
+ log_master(
+ logger.info,
+ f"Found {num_features} features in image of resolution {height}x{width}",
+ )
+ return "" * num_features
+
+ elif config.model_type == "paligemma":
+ return "" * config.text_config.num_image_tokens
+ elif config.model_type == "qwen2_vl":
+ grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+ num_pads = grid_t * grid_h * grid_w // 4
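+ # image_grid_thw counts raw vision patches; the division by 4 matches the 2x2 spatial patch merging
+ # performed by the Qwen2-VL vision encoder.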
+ padding = "<|image_pad|>" * num_pads
+ return f"<|vision_start|>{padding}<|vision_end|>"
+ elif config.model_type == "qwen2_5_vl":
+ grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+ num_pads = grid_t * grid_h * grid_w // 4
+ padding = "<|image_pad|>" * num_pads
+ return f"<|vision_start|>{padding}<|vision_end|>"
+ elif config.model_type == "gemma3":
+ # TODO: get correct number of features via reviewing the Gemma3 architecture
+ # and calculating the number of image tokens
+ num_pads = 256
+ padding = "" * num_pads
+ return f"\n\n{padding}\n\n"
+ else:
+ raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def image_text_replacement_fixup(config, text: str) -> str:
+ if config.model_type == "idefics2":
+ return text.replace(
+ f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+ )
+ return text
+
+
+def get_unpadded_features(
+ original_height: int,
+ original_width: int,
+ npatches: int,
+ num_patch_height: int,
+ num_patch_width: int,
+) -> Tuple[int, int]:
+ current_height = npatches * num_patch_height
+ current_width = npatches * num_patch_width
+
+ aspect_ratio: float = original_width / original_height
+ current_aspect_ratio: float = current_width / current_height
+
+ if aspect_ratio > current_aspect_ratio:
+ new_height = (original_height * current_width) // original_width
+ padding = (current_height - new_height) // 2
+ current_height = current_height - (2 * padding)
+ else:
+ new_width = (original_width * current_height) // original_height
+ padding = (current_width - new_width) // 2
+ current_width = current_width - (2 * padding)
+
+ unpadded_features = current_height * current_width
+ newline_features = current_height
+ return (unpadded_features, newline_features)
+
+
+def get_number_of_features(height: int, width: int, config) -> int:
+ # From config
+ # Hardcoded for CLIP for now
+ # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ image_grid_pinpoints = config.image_grid_pinpoints
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+
+ assert image_size % patch_size == 0
+
+ npatches = image_size // patch_size
+
+ # Dimensions are intentionally swapped to be bug-compatible with
+ # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ [height, width],
+ image_grid_pinpoints,
+ image_size,
+ )
+ unpadded_features, newline_features = get_unpadded_features(
+ height, width, npatches, num_patch_height, num_patch_width
+ )
+ # The base patch covers the entire image
+ base_features = npatches**2
+ return unpadded_features + newline_features + base_features
+
+
+class FlashVlmCausalLMBatch(FlashCausalLMBatch):
+ pixel_values: Optional[List[torch.Tensor]]
+ pixel_attention_mask: Optional[List[torch.Tensor]]
+ image_sizes: Optional[List[Tuple[int, int]]]
+ image_grid_thw: Optional[torch.Tensor]
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches):
+ batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.image_grid_thw = None
+ return batch
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]):
+ batch = super().filter(request_ids)
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.image_grid_thw = None
+ return batch
+
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+ ):
+ # Process images first. We need all of them so that the processor
+ # can make the image splits the same size. And we need the final
+ # sizes to insert the correct number of image tokens.
+ images = []
+ for r in requests:
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ pass
+ elif chunk_type == "image":
+ image = Image.open(BytesIO(chunk.image.data))
+ # qwen2_vl expects images larger than 20 pixels; this matters for warmup since the
+ # default warmup image is 20x20
+ if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
+ if image.width <= 20:
+ w = image.width * 2
+ h = image.height * 2
+ image = image.resize((w, h))
+
+ if config.model_type == "llava_next":
+ images.append(image)
+ elif config.model_type == "gemma3":
+ images.append(image)
+ else:
+ images.append([image])
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+ if images:
+ kwargs = {}
+ if (
+ hasattr(processor, "image_processor_class")
+ and processor.image_processor_class == "Idefics3ImageProcessor"
+ ):
+ kwargs["return_row_col_info"] = True
+
+ image_inputs = processor.image_processor(
+ images, return_tensors="pt", **kwargs
+ )
+ else:
+ image_inputs = None
+
+ batch_tokenized_inputs = []
+ max_length = 0
+ image_id = 0
+ for r in requests:
+ full_text = ""
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ full_text += chunk.text
+ elif chunk_type == "image":
+ full_text += image_text_replacement(
+ processor, image_inputs, config, image_id
+ )
+ image_id += 1
+
+ full_text = image_text_replacement_fixup(config, full_text)
+ input_ids = tokenizer(
+ full_text,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=r.add_special_tokens,
+ )["input_ids"]
+ max_length = max(max_length, len(input_ids))
+ batch_tokenized_inputs.append(input_ids)
+
+ return batch_tokenized_inputs, image_inputs
+
+ @classmethod
+ def from_pb_processor(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor,
+ config,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashVlmCausalLMBatch":
+ batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+ pb.requests, tokenizer, processor, config
+ )
+ batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+ if image_inputs is not None:
+ batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+ if "pixel_attention_mask" in image_inputs:
+ batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+ device=device
+ )
+ else:
+ batch.pixel_attention_mask = None
+ if "image_sizes" in image_inputs:
+ batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+ else:
+ batch.image_sizes = None
+ if "image_grid_thw" in image_inputs:
+ batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
+ else:
+ batch.image_grid_thw = None
+ else:
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.image_grid_thw = None
+ return batch
+
+
+class FlashVlmCausalLM(FlashCausalLM):
+ def __init__(
+ self,
+ model_id: str,
+ *,
+ processor_class=AutoProcessor,
+ processor_kwargs=None,
+ batch_class=FlashVlmCausalLMBatch,
+ revision,
+ trust_remote_code: bool,
+ **kwargs,
+ ):
+ if PREFIX_CACHING:
+ raise NotImplementedError("Vlm do not work with prefix caching yet")
+ if processor_kwargs is None:
+ processor_kwargs = {}
+ self.processor = processor_class.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ **processor_kwargs,
+ )
+ self.batch_class = batch_class
+ super().__init__(
+ model_id=model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ # FIXME: VLM do not work with context chunking yet
+ support_chunking=False,
+ **kwargs,
+ )
+
+ @property
+ def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
+ return self.batch_class
+
+ def max_past(self) -> Optional[int]:
+ return getattr(self.model.text_model, "max_past", None)
+
+ def forward(
+ self,
+ batch: FlashVlmCausalLMBatch,
+ adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # Model Forward
+ if batch.speculative_ids is not None:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ speculative_ids = batch.speculative_ids
+
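+ # Expand every sequence to (speculative_length + 1) positions so the current token and
+ # all speculative candidates are scored in a single forward pass.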
+ B, speculative_length = speculative_ids.shape
+ new_length = speculative_length + 1
+ new_input_ids = torch.cat(
+ [input_ids.unsqueeze(-1), speculative_ids], dim=1
+ ).reshape(-1)
+ arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+ arange_int = arange.to(dtype=torch.int32)
+ new_position_ids = (
+ position_ids.unsqueeze(-1).expand(B, new_length) + arange
+ ).view(-1)
+ slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+ input_lengths = (
+ input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+ cache_lengths_tensor = (
+ batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+ ).reshape(-1)
+
+ # Copy the block tables for all batch members
+ block_tables = (
+ block_tables.unsqueeze(1)
+ .expand(B, new_length, -1)
+ .reshape(B * new_length, -1)
+ .contiguous()
+ )
+ max_s = max_s + speculative_length
+
+ input_ids = new_input_ids
+ position_ids = new_position_ids
+ else:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ cache_lengths_tensor = batch.cache_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
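+ # Qwen2-VL/Qwen2.5-VL use multimodal rotary embeddings: during prefill, rebuild the
+ # position ids from image_grid_thw so text and image tokens get the right positions.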
+ if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
+ if position_ids.dim() == 1 and batch.prefilling:
+ position_ids = self.model.get_position_ids(
+ input_ids, batch.image_grid_thw
+ )
+ batch.position_ids = position_ids
+
+ if cu_seqlen_prefill is None and self.max_past() is not None:
+ # In decode, not prefill, we're actually overwriting the KV-cache
+ # in a circular buffer mode.
+ # This makes sure the max_s for the decode pass is correct.
+ max_s = min(self.max_past(), max_s)
+
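+ # In HPU lazy mode, run this forward through HPU graphs (do not bypass them).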
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = False
+
+ seqlen = Seqlen(
+ input_lengths=input_lengths,
+ cache_lengths=cache_lengths_tensor,
+ cu_seqlen_q=cu_seqlen_prefill,
+ )
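+ # Scatter the real slot indices into a zero-initialized tensor matching the padded
+ # input length; only positions in prefill_cache_indices point at their true cache slots.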
+ if batch.prefill_cache_indices is not None:
+ slots_pad = torch.zeros_like(input_ids)
+ slots_pad[batch.prefill_cache_indices] = slots
+ slots = slots_pad
+ logits, speculative_logits = self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=batch.hpu_attn_meta,
+ lm_head_indices=lm_head_indices,
+ pixel_values=batch.pixel_values,
+ pixel_attention_mask=batch.pixel_attention_mask,
+ image_sizes=batch.image_sizes,
+ image_grid_thw=batch.image_grid_thw,
+ **kwargs,
+ )
+ if batch.prefill_cache_indices is not None:
+ batch.prefill_cache_indices = None
+ if batch.pixel_values is not None:
+ batch.pixel_values = None
+ if batch.pixel_attention_mask is not None:
+ batch.pixel_attention_mask = None
+ if batch.image_sizes is not None:
+ batch.image_sizes = None
+ if batch.image_grid_thw is not None:
+ batch.image_grid_thw = None
+ return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/models/galactica.py b/backends/gaudi/server/text_generation_server/models/galactica.py
new file mode 100644
index 000000000..7c4e462c7
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/galactica.py
@@ -0,0 +1,156 @@
+import re
+import torch
+import torch.distributed
+
+
+from transformers import (
+ PreTrainedTokenizerBase,
+)
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
+ NextTokenChooser,
+ StoppingCriteria,
+)
+from text_generation_server.utils.chunks import concat_text_chunks
+
+# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
+
+# we split individual characters inside special tokens like [START_DNA]
+CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
+
+# token added to implement a custom sequence tokenization. This token is added at
+# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
+# that they do not occur in the corpus. The digits are escaped so that the token does not appear
+# literally in the source code in case we ever include it in the training data.
+SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
+
+
+def _insert_split_marker(m: re.Match):
+ """
+ Applies split markers to the sequence captured by a regex match of special
+ tokens such as [START_DNA]...[END_DNA].
+ Parameters
+ ----------
+ m : re.Match
+ Regex match containing the start token, the inner sequence and the end token
+ Returns
+ ----------
+ str - the text with the split token added
+ """
+ start_token, _, sequence, end_token = m.groups()
+ sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
+ return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
+
+
+def escape_custom_split_sequence(text):
+ """
+ Applies custom splitting to the text for GALACTICA's tokenization
+ Parameters
+ ----------
+ text : str
+ Input text to split
+ Returns
+ ----------
+ str - the text with the split token added
+ """
+ return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
+
+
+# END CREDIT
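+# Example of the transformation performed above:
+# escape_custom_split_sequence("[START_SMILES]CO[END_SMILES]")
+# -> "[START_SMILES]SPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EOSPL1T-TH1S-Pl3A5E[END_SMILES]"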
+
+
+class GalacticaCausalLMBatch(CausalLMBatch):
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "GalacticaCausalLMBatch":
+ inputs = []
+ next_token_choosers = []
+ stopping_criterias = []
+ prefix_offsets = []
+ top_n_tokens = []
+ read_offsets = []
+ requests_idx_mapping = {}
+
+ # Parse batch
+ max_truncation = 0
+ padding_right_offset = 0
+ max_decode_tokens = 0
+ for i, r in enumerate(pb.requests):
+ requests_idx_mapping[r.id] = i
+ # Add escape_custom_split_sequence to the CausalLMBatch logic
+ inputs.append(
+ escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
+ )
+ next_token_choosers.append(
+ NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+ )
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(r.top_n_tokens)
+ max_truncation = max(max_truncation, r.truncate)
+ max_decode_tokens += stopping_criteria.max_new_tokens
+ padding_right_offset = max(
+ padding_right_offset, stopping_criteria.max_new_tokens
+ )
+
+ tokenized_inputs = tokenizer(
+ inputs,
+ return_tensors="pt",
+ padding=True,
+ return_token_type_ids=False,
+ truncation=True,
+ max_length=max_truncation,
+ ).to(device)
+ for _ in pb.requests:
+ input_len = tokenized_inputs["input_ids"].shape[1]
+ prefix_offsets.append(0)
+ read_offsets.append(input_len)
+
+ input_lengths = tokenized_inputs["attention_mask"].sum(1)
+ max_input_length = input_lengths.max()
+
+ input_ids = tokenized_inputs["input_ids"]
+ # Allocate maximum attention_mask
+ attention_mask = input_ids.new_zeros(
+ (pb.size, max_input_length + padding_right_offset)
+ )
+ # Copy tokenizer attention_mask into fully allocated attention_mask
+ attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
+
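+ # Build position ids from the attention mask: left-padding positions are forced to 1,
+ # real tokens are numbered from 0.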
+ position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
+ position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
+ all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+
+ max_tokens = len(inputs) * max_input_length + max_decode_tokens
+
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=None,
+ all_input_ids=list(all_input_ids),
+ input_lengths=input_lengths.tolist(),
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length.item(),
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py
new file mode 100644
index 000000000..cd221e148
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/globals.py
@@ -0,0 +1,52 @@
+import os
+from typing import Dict, Optional
+from loguru import logger
+from text_generation_server.utils.log import log_master
+
+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
+ATTENTION = os.getenv("ATTENTION", "default")
+# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
+ "1",
+ "true",
+}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+_expected = {"paged", "default"}
+assert (
+ ATTENTION in _expected
+), f"Attention is not valid {ATTENTION}, expected {_expected}"
+log_master(logger.info, f"Using Attention = {ATTENTION}")
+
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
+assert TGI_WIGGLE_ROOM > 0
+assert TGI_WIGGLE_ROOM < 1
+
+# This is overridden by the cli
+BLOCK_SIZE: int = 128
+
+
+# This is overridden at model loading.
+global MODEL_ID
+MODEL_ID = None
+
+
+def set_model_id(model_id: str):
+ global MODEL_ID
+ MODEL_ID = model_id
+
+
+# NOTE: eventually we should move this into the router and pass back the
+# index in all cases.
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
+
+
+def set_adapter_to_index(adapter_to_index: Dict[str, int]):
+ global ADAPTER_TO_INDEX
+ ADAPTER_TO_INDEX = adapter_to_index
+
+
+def get_adapter_to_index():
+ global ADAPTER_TO_INDEX
+ return ADAPTER_TO_INDEX
diff --git a/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py
new file mode 100644
index 000000000..98d7352a8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/idefics_causal_lm.py
@@ -0,0 +1,882 @@
+from io import BytesIO
+from PIL import Image
+import torch
+import time
+
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import (
+ AutoConfig,
+ AutoProcessor,
+ AutoTokenizer,
+ PreTrainedTokenizerBase,
+ ProcessorMixin,
+)
+from typing import Optional, Tuple, List, Type, Dict
+
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+ Batch,
+ Tokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+import torch.distributed
+from text_generation_server.models.custom_modeling.idefics_modeling import (
+ IdeficsForVisionText2Text,
+)
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+from text_generation_server.utils.quantization import get_loader
+
+tracer = trace.get_tracer(__name__)
+
+
+@dataclass
+class IdeficsCausalLMBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ requests_idx_mapping: Dict[int, int]
+
+ # Decoder values
+ input_ids: torch.Tensor
+ attention_mask: torch.Tensor
+ position_ids: torch.Tensor
+ pixel_values: Optional[torch.Tensor]
+ image_hidden_states: Optional[torch.Tensor]
+ image_attention_mask: Optional[torch.Tensor]
+ past_key_values: Optional[List[Tuple]]
+
+ # All tokens
+ all_input_ids: List[torch.Tensor]
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
+
+ # Generation helpers
+ next_token_choosers: List[NextTokenChooser]
+ stopping_criterias: List[StoppingCriteria]
+
+ # Metadata used for padding
+ max_input_length: int
+ padding_right_offset: int
+
+ # Maximum number of tokens this batch will grow to
+ max_tokens: int
+
+ # Past metadata
+ keys_head_dim_last: bool = True
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.max_tokens,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "IdeficsCausalLMBatch":
+ raise NotImplementedError
+
+ @classmethod
+ def from_pb_processor(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor: ProcessorMixin, # Hack
+ config,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "IdeficsCausalLMBatch":
+ inputs = []
+ next_token_choosers = []
+ stopping_criterias = []
+ prefix_offsets = []
+ read_offsets = []
+ requests_idx_mapping = {}
+
+ # Parse batch
+ max_truncation = 0
+ padding_right_offset = 0
+ max_decode_tokens = 0
+ for i, r in enumerate(pb.requests):
+ requests_idx_mapping[r.id] = i
+ inputs.append(r.input_chunks.chunks)
+ next_token_choosers.append(
+ NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+ )
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ stopping_criterias.append(stopping_criteria)
+ max_truncation = max(max_truncation, r.truncate)
+ max_decode_tokens += stopping_criteria.max_new_tokens
+ padding_right_offset = max(
+ padding_right_offset, stopping_criteria.max_new_tokens
+ )
+
+ # TODO Check impact on idefics
+ prompts = []
+ for inp in inputs:
+ # Each input is encoded into a list, where each element of this input list is either a string or a URL
+ prompt = []
+ for chunk in inp:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ prompt.append(chunk.text)
+ elif chunk_type == "image":
+ image = Image.open(BytesIO(chunk.image.data))
+ prompt.append(image)
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+ prompts.append(prompt)
+
+ # The processor replaces the call to the tokenizer, and
+ # a/ takes care of fetching images from the URLs
+ # b/ generates the correct input_ids, attention_mask, pixel_values and image_attention_mask to feed to the model
+ tokenized_inputs = processor(
+ prompts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=max_truncation,
+ # TODO Check impact on idefics
+ # add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
+ ).to(device)
+ for _ in pb.requests:
+ input_len = tokenized_inputs["input_ids"].shape[1]
+ prefix_offsets.append(
+ input_len - 5
+ ) # To decode without potential fallback errors
+ read_offsets.append(
+ input_len
+ ) # To decode without potential fallback errors
+
+ input_lengths = tokenized_inputs["attention_mask"].sum(1)
+ max_input_length = input_lengths.max()
+
+ input_ids = tokenized_inputs["input_ids"]
+ pixel_values = tokenized_inputs.get("pixel_values", None)
+ image_hidden_states = None
+ # Allocate maximum attention_mask
+ attention_mask = input_ids.new_zeros(
+ (pb.size, max_input_length + padding_right_offset)
+ )
+ # Copy tokenizer attention_mask into fully allocated attention_mask
+ attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
+ # Do the same for image_attention_mask
+ if pixel_values is None:
+ image_attention_mask = None
+ else:
+ image_attention_mask = input_ids.new_zeros(
+ (
+ pb.size,
+ max_input_length + padding_right_offset,
+ pixel_values.size(1),
+ )
+ )
+ image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
+ "image_attention_mask"
+ ]
+
+ position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
+ position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
+ all_input_ids = tokenized_inputs["input_ids"].T.split(
+ 1, dim=1
+ ) # It's input_ids split into a tuple of tensors of size (seq_len, 1); it is then turned into a list
+
+ max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ past_key_values=None,
+ all_input_ids=list(all_input_ids),
+ input_lengths=input_lengths.tolist(),
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ max_input_length=max_input_length.item(),
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]:
+ # It deletes requests from the batch, for instance when a client loses its connection
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ if len(request_ids) == len(self):
+ return self
+
+ keep_indices = []
+
+ # New values after filtering
+ requests_idx_mapping = {}
+ requests = []
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ max_input_length = 0
+
+ next_token_choosers = []
+ stopping_criterias = []
+
+ total_remaining_decode_tokens = 0
+ new_padding_right_offset = 0
+
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ requests_idx_mapping[request_id] = i
+ keep_indices.append(idx)
+
+ requests.append(self.requests[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+ all_input_ids.append(self.all_input_ids[idx])
+
+ request_input_length = self.input_lengths[idx]
+ input_lengths.append(request_input_length)
+ max_input_length = max(max_input_length, request_input_length)
+
+ next_token_choosers.append(self.next_token_choosers[idx])
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+ remaining_decode_tokens = (
+ stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+ )
+ total_remaining_decode_tokens += remaining_decode_tokens
+ new_padding_right_offset = max(
+ new_padding_right_offset, remaining_decode_tokens
+ )
+
+ # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+ input_ids = self.input_ids[keep_indices]
+ position_ids = self.position_ids[keep_indices]
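+ # Re-slice the attention mask to the filtered requests: keep only the last
+ # max_input_length input columns plus the new right padding for the remaining decode tokens.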
+ self.attention_mask = self.attention_mask[
+ keep_indices,
+ -(self.padding_right_offset + max_input_length) : (
+ self.attention_mask.shape[1] - self.padding_right_offset
+ )
+ + new_padding_right_offset,
+ ]
+ # Do the same for pixel_values and image_attention_mask
+ pixel_values = self.pixel_values[keep_indices]
+ self.image_attention_mask = self.image_attention_mask[
+ keep_indices,
+ -(self.padding_right_offset + max_input_length) : (
+ self.image_attention_mask.shape[1] - self.padding_right_offset
+ )
+ + new_padding_right_offset,
+ :,
+ ]
+ if self.image_hidden_states is None:
+ image_hidden_states = None
+ else:
+ image_hidden_states = self.image_hidden_states[keep_indices]
+
+ # Ensure that past_key_values tensors can be updated in-place
+ if type(self.past_key_values[0]) is tuple:
+ self.past_key_values = [list(layer) for layer in self.past_key_values]
+
+ # Update tensors in-place to allow incremental garbage collection
+ past_kv_length = max_input_length - 1
+ for layer in self.past_key_values:
+ past_keys, past_values = layer
+ if len(past_keys.shape) == 3:
+ # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+ past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+ past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
+ if self.keys_head_dim_last:
+ layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
+ else:
+ layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
+ del past_keys
+ layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
+ del past_values
+
+ max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
+
+ self.requests = requests
+ self.requests_idx_mapping = requests_idx_mapping
+ self.input_ids = input_ids
+ self.pixel_values = pixel_values
+ self.image_hidden_states = image_hidden_states
+ self.position_ids = position_ids
+ self.all_input_ids = all_input_ids
+ self.input_lengths = input_lengths
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
+ self.next_token_choosers = next_token_choosers
+ self.stopping_criterias = stopping_criterias
+ self.max_input_length = max_input_length
+ self.padding_right_offset = new_padding_right_offset
+ self.max_tokens = max_tokens
+
+ return self
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(
+ cls, batches: List["IdeficsCausalLMBatch"]
+ ) -> "IdeficsCausalLMBatch":
+ # It merges several batches into one, adding their requests together
+ # Used for padding
+ total_batch_size = 0
+ max_input_length = 0
+ max_num_images = 0
+ padding_right_offset = 0
+ for batch in batches:
+ total_batch_size += len(batch)
+ max_input_length = max(max_input_length, batch.max_input_length)
+ max_num_images = max(max_num_images, batch.pixel_values.size(1))
+ padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ next_token_choosers = []
+ stopping_criterias = []
+ max_tokens = 0
+
+ # Batch tensors
+ input_ids = None
+ attention_mask = None
+ position_ids = None
+ pixel_values = None
+ image_hidden_states = None
+ image_attention_mask = None
+ past_key_values = []
+
+ # Used for slicing correctly inside the tensors
+ # Equivalent to a cumsum on batch sizes
+ start_index = 0
+ for i, batch in enumerate(batches):
+ requests.extend(batch.requests)
+ input_lengths.extend(batch.input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+ all_input_ids.extend(batch.all_input_ids)
+ next_token_choosers.extend(batch.next_token_choosers)
+ stopping_criterias.extend(batch.stopping_criterias)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + start_index
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+
+ # We only concatenate batches that did at least one step
+ if batch.past_key_values is None:
+ raise ValueError("only concatenate prefilled batches")
+
+ # Create empty tensor
+ # input_ids is always of shape [batch_size, 1]
+ # We do not need to pad it
+ if input_ids is None:
+ input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+ # Copy to correct indices
+ input_ids[start_index:end_index] = batch.input_ids
+
+ # Create padded tensor
+ if attention_mask is None:
+ attention_mask = batch.attention_mask.new_zeros(
+ (total_batch_size, max_input_length + padding_right_offset),
+ )
+
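+ # Pad pixel_values to the largest image count across batches; IDEFICS images are
+ # always 3x224x224, hence the hard-coded shape.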
+ curr_batch_max_num_images = batch.pixel_values.size(1)
+ if pixel_values is None:
+ pixel_values = batch.pixel_values.new_zeros(
+ (total_batch_size, max_num_images, 3, 224, 224)
+ )
+ pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
+ batch.pixel_values
+ )
+
+ if image_attention_mask is None:
+ image_attention_mask = batch.image_attention_mask.new_zeros(
+ (
+ total_batch_size,
+ max_input_length + padding_right_offset,
+ max_num_images,
+ )
+ )
+
+ # We need to slice the attention mask to remove padding from previous steps
+ # and to remove unused allocated space
+ left_offset = max_input_length - batch.max_input_length
+ batch_left_offset = (
+ batch.attention_mask.shape[1]
+ - batch.max_input_length
+ - batch.padding_right_offset
+ )
+ attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ ] = batch.attention_mask[
+ :,
+ batch_left_offset : -batch.padding_right_offset,
+ ]
+ image_attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ :curr_batch_max_num_images,
+ ] = batch.image_attention_mask[
+ :, batch_left_offset : -batch.padding_right_offset, :
+ ]
+
+ # Create empty tensor
+ # position_ids is always of shape [batch_size, 1]
+ if position_ids is None:
+ position_ids = batch.position_ids.new_empty((total_batch_size, 1))
+ position_ids[start_index:end_index] = batch.position_ids
+
+ # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+ # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
+ # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+ # And ensure that we can update tensors in-place
+ if isinstance(batch.past_key_values[0], tuple):
+ batch.past_key_values = [
+ [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+ for layer in batch.past_key_values
+ ]
+ elif len(batch.past_key_values[0][0].shape) == 3:
+ for layer in batch.past_key_values:
+ for k, t in enumerate(layer):
+ layer[k] = t.view(len(batch), -1, *t.shape[-2:])
+
+ # Account for any padding tokens added while concatenating
+ max_tokens += batch.max_tokens + (
+ max_input_length - batch.max_input_length
+ ) * len(batch)
+
+ start_index = end_index
+
+ first_past_kvs = batches[0].past_key_values
+ _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
+
+ padded_past_values_shape = (
+ total_batch_size,
+ num_heads,
+ max_input_length - 1,
+ head_dim,
+ )
+
+ if batches[0].keys_head_dim_last:
+ padded_past_keys_shape = padded_past_values_shape
+ else:
+ # seq_length is last for BLOOM
+ padded_past_keys_shape = (
+ total_batch_size,
+ num_heads,
+ head_dim,
+ max_input_length - 1,
+ )
+
+ # Iterate over attention layers
+ # Concatenate past key values layer by layer to allow incremental garbage collection
+ for j in range(len(first_past_kvs)):
+ padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
+ start_index = 0
+ for batch in batches:
+ past_keys = batch.past_key_values[j][0]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][0] = None
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the keys to remove the padding from previous batches
+ past_seq_len = batch.max_input_length - 1
+ if batch.keys_head_dim_last:
+ padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+ past_keys[:, :, -past_seq_len:, :]
+ )
+ else:
+ # BLOOM case
+ padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+ past_keys[:, :, :, -past_seq_len:]
+ )
+ del past_keys
+
+ start_index = end_index
+
+ padded_past_values = first_past_kvs[j][1].new_zeros(
+ padded_past_values_shape
+ )
+ start_index = 0
+ for batch in batches:
+ past_values = batch.past_key_values[j][1]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][1] = None
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the past values to remove the padding from previous batches
+ past_seq_len = batch.max_input_length - 1
+ padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+ past_values[:, :, -past_seq_len:, :]
+ )
+ del past_values
+
+ # Update values
+ start_index = end_index
+
+ past_key_values.append([padded_past_keys, padded_past_values])
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_hidden_states=image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ past_key_values=past_key_values,
+ all_input_ids=all_input_ids,
+ input_lengths=input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ max_input_length=max_input_length,
+ padding_right_offset=padding_right_offset,
+ keys_head_dim_last=batches[0].keys_head_dim_last,
+ max_tokens=max_tokens,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+class IdeficsCausalLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
+ self.quantize = quantize
+ self.process_group, rank, world_size = initialize_torch_distributed()
+ device = torch.device("hpu")
+ dtype = torch.bfloat16 if dtype is None else dtype
+ self.device, self.dtype = device, dtype
+
+ config = AutoConfig.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ )
+ config.quantize = quantize
+ config.speculator = speculator
+ config.vision_config.quantize = quantize
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ self.processor = AutoProcessor.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+
+ weights_loader = get_loader(
+ quantize=quantize, model_id=model_id, revision=revision
+ )
+ torch.distributed.barrier(group=self.process_group)
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device=device,
+ dtype=dtype,
+ process_group=self.process_group,
+ weights_loader=weights_loader,
+ )
+
+ model = IdeficsForVisionText2Text(config, weights)
+
+ self.config = config
+
+ torch.distributed.barrier(group=self.process_group)
+ super().__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ )
+
+ @property
+ def batch_type(self) -> Type[IdeficsCausalLMBatch]:
+ return IdeficsCausalLMBatch
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ pixel_values,
+ image_hidden_states,
+ image_attention_mask,
+ past_key_values: Optional = None,
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+ # Model Forward
+ kwargs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "image_hidden_states": image_hidden_states,
+ "image_attention_mask": image_attention_mask,
+ "past_key_values": past_key_values,
+ "use_cache": True,
+ "return_dict": True,
+ }
+ if self.has_position_ids:
+ kwargs["position_ids"] = position_ids
+
+ outputs, speculative_logits = self.model.forward(**kwargs)
+ return (
+ outputs.logits,
+ speculative_logits,
+ outputs.past_key_values,
+ outputs.image_hidden_states,
+ )
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batch: IdeficsCausalLMBatch
+ ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
+ start = time.time_ns()
+ # slice the attention mask to the correct shape
+ attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+ if batch.image_attention_mask is None:
+ image_attention_mask = None
+ else:
+ if batch.input_ids.size(1) == 1:
+ # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
+ # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
+ # this is due to the nature of IDEFICS: it is an encoder-decoder, and so when decoding, only the currently generated
+ # token needs to attend to the encoder hidden states (i.e. the vision encoder)
+ # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
+ image_attention_mask = batch.image_attention_mask[
+ :, -(batch.padding_right_offset + 1)
+ ].unsqueeze(1)
+ else:
+ image_attention_mask = batch.image_attention_mask[
+ :, : -batch.padding_right_offset
+ ]
+
+ logits, speculative_logits, past, image_hidden_states = self.forward(
+ input_ids=batch.input_ids,
+ attention_mask=attention_mask,
+ position_ids=batch.position_ids,
+ pixel_values=batch.pixel_values,
+ image_hidden_states=batch.image_hidden_states,
+ image_attention_mask=image_attention_mask,
+ past_key_values=batch.past_key_values,
+ )
+ # Hardcoded remove image tokens
+ logits[:, 32000:32001] = torch.finfo(logits.dtype).min
+
+ start_decode = time.time_ns()
+
+ # Results
+ generations: List[Generation] = []
+ stopped = True
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.input_lengths,
+ batch.prefix_offsets,
+ batch.read_offsets,
+ logits,
+ batch.next_token_choosers,
+ batch.stopping_criterias,
+ batch.all_input_ids,
+ )
+
+ # For each member of the batch
+ for i, (
+ request,
+ input_length,
+ prefix_offset,
+ read_offset,
+ logits,
+ next_token_chooser,
+ stopping_criteria,
+ all_input_ids,
+ ) in enumerate(iterator):
+ # Select next token
+ next_token_id, logprobs = next_token_chooser(
+ all_input_ids.view(1, -1), logits[-1:, :]
+ )
+
+ # Append next token to all tokens
+ all_input_ids = torch.cat([all_input_ids, next_token_id])
+ new_input_length = input_length + 1
+
+ # Generated token
+ next_token_logprob = logprobs[-1, next_token_id]
+ next_token_id_squeezed = next_token_id.squeeze()
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids[:, 0], prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(
+ next_token_id_squeezed,
+ next_token_text,
+ )
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ output_text, _, _ = self.decode_token(
+ all_input_ids[:, 0],
+ prefix_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens
+ - 1,
+ read_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens,
+ skip_special_tokens=True,
+ )
+ # Get seed
+ if isinstance(next_token_chooser.choice, Sampling):
+ seed = next_token_chooser.choice.seed
+ else:
+ seed = None
+
+ generated_text = GeneratedText(
+ output_text, stopping_criteria.current_tokens, reason, seed
+ )
+ else:
+ generated_text = None
+
+ # Prefill
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ # Remove generated token to only have prefill and add nan for first prompt token
+ prefill_logprobs = [float("nan")] + torch.log_softmax(
+ logits, -1
+ ).gather(1, all_input_ids[1:]).squeeze(1)[
+ -new_input_length:-1
+ ].tolist()
+ prefill_token_ids = all_input_ids[-new_input_length:-1]
+ prefill_texts = self.tokenizer.batch_decode(
+ prefill_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ prefill_tokens = Tokens(
+ prefill_token_ids,
+ prefill_logprobs,
+ prefill_texts,
+ is_special=[],
+ )
+ else:
+ prefill_tokens = None
+
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ Tokens(
+ [next_token_id_squeezed],
+ [next_token_logprob],
+ [next_token_text],
+ [next_token_id_squeezed.item() in self.all_special_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ # Update values
+ batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+ next_token_id_squeezed.item()
+ )
+ batch.input_ids[i, 0] = next_token_id
+ batch.all_input_ids[i] = all_input_ids
+ batch.input_lengths[i] = new_input_length
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.max_input_length = max(batch.max_input_length, new_input_length)
+
+ # We finished all generations in the batch; there is no next batch
+ if stopped:
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, None, (forward_ns, decode_ns)
+
+ # Slice unused values from prefill
+ batch.input_ids = batch.input_ids[:, :1]
+
+ # Update attention_mask as we added a new token to input_ids
+ batch.attention_mask[:, -batch.padding_right_offset] = 1
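+ # Propagate the image attention from the previous position to the newly generated token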
+ batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
+ batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+ )
+ # Decrease right offset
+ batch.padding_right_offset -= 1
+
+ # Update position_ids
+ batch.position_ids = batch.position_ids[:, -1:] + 1
+
+ # Update past key values
+ batch.past_key_values = past
+ batch.image_hidden_states = image_hidden_states
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch, (forward_ns, decode_ns)
diff --git a/backends/gaudi/server/text_generation_server/models/mamba.py b/backends/gaudi/server/text_generation_server/models/mamba.py
new file mode 100644
index 000000000..f6dcde68a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/mamba.py
@@ -0,0 +1,814 @@
+import torch
+import torch.distributed
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from typing import Optional
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+ MambaConfig,
+)
+from loguru import logger
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
+import time
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+ MambaModel,
+ InferenceParams,
+)
+from text_generation_server.models import Model
+from typing import Any, List, Tuple, Type, Dict
+from text_generation_server.models.types import (
+ Batch,
+ Tokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.quantization import get_loader
+from text_generation_server.utils.tokens import batch_top_tokens, Sampling
+from dataclasses import dataclass
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria
+
+
+def new_inference_params(
+ n_blocks: int,
+ batch_size: int,
+ d_inner: int,
+ d_conv: int,
+ d_state: int,
+ seqlen_offset: int,
+ dtype: torch.dtype,
+ device: torch.device,
+):
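+ # Mamba keeps two per-layer recurrent buffers instead of a KV cache: conv_states for the
+ # depthwise convolution window and ssm_states for the SSM hidden state.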
+ max_seqlen = 0
+ conv_states = torch.zeros(
+ (
+ n_blocks,
+ batch_size,
+ d_inner,
+ d_conv,
+ ),
+ device=device,
+ dtype=dtype,
+ )
+ ssm_states = torch.zeros(
+ (
+ n_blocks,
+ batch_size,
+ d_inner,
+ d_state,
+ ),
+ device=device,
+ dtype=dtype,
+ )
+ inference_params = InferenceParams(
+ max_seqlen=max_seqlen,
+ max_batch_size=batch_size,
+ seqlen_offset=seqlen_offset,
+ conv_states=conv_states,
+ ssm_states=ssm_states,
+ )
+ return inference_params
+
+
+@dataclass
+class MambaBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ requests_idx_mapping: Dict[int, int]
+
+ # Decoder values
+ input_ids: torch.Tensor
+
+ # All tokens
+ all_input_ids: List[torch.Tensor]
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
+
+ # Generation helpers
+ next_token_choosers: List[NextTokenChooser]
+ stopping_criterias: List[StoppingCriteria]
+ top_n_tokens: List[int]
+ top_n_tokens_tensor: torch.Tensor
+
+ # Metadata used for padding
+ max_input_length: int
+ padding_right_offset: int
+
+ # Maximum number of tokens this batch will grow to
+ max_tokens: int
+
+ # Past metadata
+ keys_head_dim_last: bool = True
+
+ # Inference params
+ inference_params: Optional[Dict[str, Any]] = None
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.max_tokens,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "MambaBatch":
+ inputs = []
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+ prefix_offsets = []
+ read_offsets = []
+ requests_idx_mapping = {}
+
+ # Parse batch
+ max_truncation = 0
+ padding_right_offset = 0
+ max_decode_tokens = 0
+ for i, r in enumerate(pb.requests):
+ requests_idx_mapping[r.id] = i
+ inputs.append(concat_text_chunks(r.input_chunks.chunks))
+ next_token_choosers.append(
+ NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+ )
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(r.top_n_tokens)
+ max_truncation = max(max_truncation, r.truncate)
+ max_decode_tokens += stopping_criteria.max_new_tokens
+ padding_right_offset = max(
+ padding_right_offset, stopping_criteria.max_new_tokens
+ )
+
+ tokenized_inputs = tokenizer(
+ inputs,
+ return_tensors="pt",
+ padding=True,
+ return_token_type_ids=False,
+ truncation=True,
+ max_length=max_truncation,
+ ).to(device)
+ for _ in pb.requests:
+ input_len = tokenized_inputs["input_ids"].shape[1]
+ prefix_offsets.append(input_len - 5)
+ read_offsets.append(input_len)
+
+ input_lengths = tokenized_inputs["attention_mask"].sum(1)
+ max_input_length = input_lengths.max()
+ input_ids = tokenized_inputs["input_ids"]
+ all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+ max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ # past_input_ids=None,
+ all_input_ids=list(all_input_ids),
+ input_lengths=input_lengths.tolist(),
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length.item(),
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
+
+ def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ if len(request_ids) == len(self):
+ return self
+
+ keep_indices = []
+
+ # New values after filtering
+ requests_idx_mapping = {}
+ requests = []
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ max_input_length = 0
+
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+
+ total_remaining_decode_tokens = 0
+ new_padding_right_offset = 0
+
+ indices = []
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ requests_idx_mapping[request_id] = i
+ keep_indices.append(idx)
+
+ requests.append(self.requests[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+ all_input_ids.append(self.all_input_ids[idx])
+
+ request_input_length = self.input_lengths[idx]
+ input_lengths.append(request_input_length)
+ max_input_length = max(max_input_length, request_input_length)
+ indices.append(idx)
+
+ next_token_choosers.append(self.next_token_choosers[idx])
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(self.top_n_tokens[idx])
+ remaining_decode_tokens = (
+ stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+ )
+ total_remaining_decode_tokens += remaining_decode_tokens
+ new_padding_right_offset = max(
+ new_padding_right_offset, remaining_decode_tokens
+ )
+
+ # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+ input_ids = self.input_ids[keep_indices]
+
+ top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
+ max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
+
+ self.requests = requests
+ self.requests_idx_mapping = requests_idx_mapping
+ self.input_ids = input_ids
+ self.all_input_ids = all_input_ids
+ self.input_lengths = input_lengths
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
+ self.next_token_choosers = next_token_choosers
+ self.stopping_criterias = stopping_criterias
+ self.top_n_tokens = top_n_tokens
+ self.top_n_tokens_tensor = top_n_tokens_tensor
+ self.max_input_length = max_input_length
+ self.padding_right_offset = new_padding_right_offset
+ self.max_tokens = max_tokens
+
+ # TODO
+ # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
+ self.inference_params.conv_states = self.inference_params.conv_states[
+ :, indices
+ ]
+ self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
+ return self
+
+ @classmethod
+ def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
+ # Used for padding
+ total_batch_size = 0
+ max_input_length = 0
+ padding_right_offset = 0
+ for batch in batches:
+ total_batch_size += len(batch)
+ max_input_length = max(max_input_length, batch.max_input_length)
+ padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+ input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ all_input_ids = []
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+ max_tokens = 0
+ seqlen_offset = 0
+
+ (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
+ (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
+ dtype = batches[0].inference_params.conv_states.dtype
+ device = batches[0].inference_params.conv_states.device
+ inference_params = new_inference_params(
+ n_blocks=n_blocks,
+ batch_size=total_batch_size,
+ d_state=d_state,
+ d_conv=d_conv,
+ d_inner=d_inner,
+ seqlen_offset=seqlen_offset,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Batch tensors
+ input_ids = None
+ top_n_tokens_tensor = None
+
+ # Used for slicing correctly inside the tensors
+ # Equivalent to a cumsum on batch sizes
+ start_index = 0
+ for i, batch in enumerate(batches):
+ requests.extend(batch.requests)
+ input_lengths.extend(batch.input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+ all_input_ids.extend(batch.all_input_ids)
+ next_token_choosers.extend(batch.next_token_choosers)
+ stopping_criterias.extend(batch.stopping_criterias)
+ top_n_tokens.extend(batch.top_n_tokens)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + start_index
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+
+ # Create empty tensor
+ # input_ids is always of shape [batch_size, 1]
+ # We do not need to pad it
+ if input_ids is None:
+ input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+ # Copy to correct indices
+ input_ids[start_index:end_index] = batch.input_ids
+
+ if top_n_tokens_tensor is None:
+ top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+ top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
+ # Account for any padding tokens added while concatenating
+ max_tokens += batch.max_tokens + (
+ max_input_length - batch.max_input_length
+ ) * len(batch)
+
+ inference_params.max_seqlen = max(
+ inference_params.max_seqlen, batch.inference_params.max_seqlen
+ )
+ assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
+ inference_params.seqlen_offset = max(
+ inference_params.seqlen_offset, batch.inference_params.seqlen_offset
+ )
+
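+ # Copy each batch's recurrent states into its slice of the merged state tensors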
+ inference_params.conv_states[:, start_index:end_index] = (
+ batch.inference_params.conv_states
+ )
+ inference_params.ssm_states[:, start_index:end_index] = (
+ batch.inference_params.ssm_states
+ )
+
+ start_index = end_index
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ all_input_ids=all_input_ids,
+ input_lengths=input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length,
+ padding_right_offset=padding_right_offset,
+ keys_head_dim_last=batches[0].keys_head_dim_last,
+ max_tokens=max_tokens,
+ inference_params=inference_params,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+class Mamba(Model):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
+ self.quantize = quantize
+ self.process_group, _rank, world_size = initialize_torch_distributed()
+ if world_size > 1:
+ raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
+ self.cuda_graphs = {}
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ # Bf16 is important. In f16, accumulations in the matmul cause
+ # differences while the server is under load.
+ # This is detectable by the integration load test
+ dtype = torch.bfloat16 if dtype is None else dtype
+ else:
+ if quantize:
+ raise ValueError("quantization is not available on CPU")
+
+ device = torch.device("cpu")
+ dtype = torch.float32 if dtype is None else dtype
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ "EleutherAI/gpt-neox-20b",
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ config = MambaConfig.from_pretrained(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+
+ tokenizer.bos_token_id = config.bos_token_id
+ tokenizer.eos_token_id = config.eos_token_id
+ tokenizer.pad_token = tokenizer.eos_token
+
+ config.quantize = quantize
+ config.speculator = speculator
+ torch.distributed.barrier(group=self.process_group)
+ weights_loader = get_loader(
+ quantize=quantize, model_id=model_id, revision=revision
+ )
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device,
+ dtype,
+ process_group=self.process_group,
+ weights_loader=weights_loader,
+ )
+ model = MambaModel(config, weights)
+ torch.distributed.barrier(group=self.process_group)
+ super(Mamba, self).__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ )
+
+ @property
+ def batch_type(self) -> Type[MambaBatch]:
+ return MambaBatch
+
+ def warmup(self, batch) -> Optional[int]:
+ # TODO: implement warmup for Mamba if needed
+ if CUDA_GRAPHS:
+ if self.speculate is None or self.speculate == 0:
+ try:
+ logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
+ # Warmup cuda graphs
+ for bs in CUDA_GRAPHS:
+ self.cuda_graph_warmup(bs)
+ except Exception:
+ logger.exception("Decode cuda graph warmup failed")
+ else:
+ logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
+
+ return None
+
+ def cuda_graph_warmup(self, batch_size: int):
+ input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
+ n_blocks = len(self.model.blocks)
+
+ d_state = self.model.config.d_state
+ d_conv = self.model.config.d_conv
+ # d_inner already includes the expand multiplier
+ d_inner = self.model.config.d_inner
+
+ # seqlen_offset must be nonzero so the forward pass goes through the state-update mechanism
+ seqlen_offset = 1
+ inference_params = new_inference_params(
+ n_blocks=n_blocks,
+ batch_size=batch_size,
+ d_state=d_state,
+ d_conv=d_conv,
+ d_inner=d_inner,
+ seqlen_offset=seqlen_offset,
+ device=self.device,
+ dtype=self.dtype,
+ )
+
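+ # Record a single decode step into a CUDA graph; the tensors above become the graph's
+ # static inputs/outputs and are reused at replay time in forward().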
+ graph = torch.cuda.CUDAGraph()
+
+ torch.cuda.synchronize()
+ # Run once outside to warmup
+ self.model.forward(input_ids=input_ids, inference_params=inference_params)
+ torch.cuda.synchronize()
+
+ with torch.cuda.graph(graph, pool=MEM_POOL):
+ logits, speculative_logits = self.model.forward(
+ input_ids=input_ids, inference_params=inference_params
+ )
+ torch.cuda.synchronize()
+ graph_dict = {
+ "input_ids": input_ids,
+ "inference_params": inference_params,
+ "graph": graph,
+ "logits": logits,
+ "speculative_logits": speculative_logits,
+ }
+ self.cuda_graphs[batch_size] = graph_dict
+
+ def tunableop_warmup(self, batch_size: int, seqlen: int):
+ input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
+ n_blocks = len(self.model.blocks)
+
+ d_state = self.model.config.d_state
+ d_conv = self.model.config.d_conv
+ # d_inner already includes the expand multiplier
+ d_inner = self.model.config.d_inner
+
+ # seqlen_offset must be nonzero so the forward pass goes through the state-update mechanism
+ seqlen_offset = 1
+ inference_params = new_inference_params(
+ n_blocks=n_blocks,
+ batch_size=seqlen,
+ d_state=d_state,
+ d_conv=d_conv,
+ d_inner=d_inner,
+ seqlen_offset=seqlen_offset,
+ device=self.device,
+ dtype=self.dtype,
+ )
+
+ self.model.forward(input_ids=input_ids, inference_params=inference_params)
+
+ def forward(
+ self, input_ids: torch.Tensor, inference_params: Any
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ bs = input_ids.shape[0]
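+ # Round the batch size up to a CUDA-graph-captured size (4, 8, then multiples of 8),
+ # e.g. bs=5 -> padded_bs=8, bs=17 -> padded_bs=24, so a recorded graph can be reused when one exists.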
+ padded_bs = bs
+ if bs == 3:
+ padded_bs = 4
+ elif 3 < bs <= 8:
+ padded_bs = 8
+ elif bs > 8:
+ padded_bs = (bs + 7) // 8 * 8
+
+ # Try to find an associated cuda graph
+ cuda_graph = self.cuda_graphs.get(padded_bs, None)
+ is_prefill = inference_params is None or inference_params.seqlen_offset == 0
+
+ if is_prefill or cuda_graph is None:
+ return self.model(
+ input_ids,
+ inference_params=inference_params,
+ )
+
+ # Copy inputs to the static inputs of the cuda graph
+ # Static inputs are potentially padded
+ cuda_graph["input_ids"][:bs] = input_ids
+ cuda_graph["inference_params"].conv_states[
+ :, :bs
+ ] = inference_params.conv_states
+ cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
+
+ # Replay the graph
+ cuda_graph["graph"].replay()
+
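+ # Copy the graph's static state buffers back into the caller's inference_params so the
+ # recurrent state stays consistent after replay.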
+ inference_params.conv_states.copy_(
+ cuda_graph["inference_params"].conv_states[:, :bs]
+ )
+ inference_params.ssm_states.copy_(
+ cuda_graph["inference_params"].ssm_states[:, :bs]
+ )
+ # Slice output to the correct shape
+ speculative_logits = (
+ cuda_graph["speculative_logits"][:bs]
+ if cuda_graph["speculative_logits"] is not None
+ else None
+ )
+ logits = cuda_graph["logits"][:bs]
+ return logits, speculative_logits
+
+ def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
+ start = time.time_ns()
+ input_ids = (
+ batch.input_ids
+ ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+
+ batch_size, max_seqlen = input_ids.shape
+ # Inference params
+
+ if batch.inference_params is None:
+ # 0 is important here: seqlen_offset == 0 marks this pass as a prefill
+ seqlen_offset = 0
+ n_blocks = len(self.model.blocks)
+ d_state = self.model.config.d_state
+ d_conv = self.model.config.d_conv
+ d_inner = self.model.config.d_inner
+ inference_params = new_inference_params(
+ n_blocks=n_blocks,
+ batch_size=batch_size,
+ d_state=d_state,
+ d_conv=d_conv,
+ d_inner=d_inner,
+ seqlen_offset=seqlen_offset,
+ device=self.device,
+ dtype=self.dtype,
+ )
+ batch.inference_params = inference_params
+
+ # Forward pass
+ logits, speculative_logits = self.forward(
+ input_ids, inference_params=batch.inference_params
+ )
+
+ # batch.inference_params = new_inference_params
+ # Results
+ generations: List[Generation] = []
+ stopped = True
+
+ # Speculation is not active for causal
+ accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens,
+ batch.top_n_tokens_tensor,
+ torch.log_softmax(logits[:, -1], -1),
+ accepted_ids,
+ )
+
+ start_decode = time.time_ns()
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.input_lengths,
+ batch.prefix_offsets,
+ batch.read_offsets,
+ logits,
+ batch.next_token_choosers,
+ batch.stopping_criterias,
+ batch.all_input_ids,
+ batch.top_n_tokens,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
+ )
+
+ # For each member of the batch
+ for i, (
+ request,
+ input_length,
+ prefix_offset,
+ read_offset,
+ logits,
+ next_token_chooser,
+ stopping_criteria,
+ all_input_ids,
+ top_n_tokens,
+ top_token_ids,
+ top_token_logprobs,
+ ) in enumerate(iterator):
+ # Select next token
+ next_token_id, logprobs = next_token_chooser(
+ all_input_ids.view(1, -1), logits[-1:, :]
+ )
+
+ # Append next token to all tokens
+ all_input_ids = torch.cat([all_input_ids, next_token_id])
+ new_input_length = input_length + 1
+
+ # Generated token
+ next_token_logprob = logprobs[-1, next_token_id]
+ next_token_id_squeezed = next_token_id.squeeze()
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids[:, 0], prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(
+ next_token_id_squeezed,
+ next_token_text,
+ )
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ output_text, _, _ = self.decode_token(
+ all_input_ids[:, 0],
+ prefix_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens
+ - 1,
+ read_offset=len(all_input_ids)
+ - stopping_criteria.current_tokens,
+ skip_special_tokens=True,
+ )
+ # Get seed
+ if isinstance(next_token_chooser.choice, Sampling):
+ seed = next_token_chooser.choice.seed
+ else:
+ seed = None
+
+ generated_text = GeneratedText(
+ output_text, stopping_criteria.current_tokens, reason, seed
+ )
+ else:
+ generated_text = None
+
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ # Remove generated token to only have prefill and add nan for first prompt token
+ prefill_logprobs = [float("nan")] + torch.log_softmax(
+ logits, -1
+ ).gather(1, all_input_ids[1:]).squeeze(1)[
+ -new_input_length:-1
+ ].tolist()
+ prefill_token_ids = all_input_ids[-new_input_length:-1]
+ prefill_texts = self.tokenizer.batch_decode(
+ prefill_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ prefill_tokens = Tokens(
+ prefill_token_ids,
+ prefill_logprobs,
+ prefill_texts,
+ is_special=[],
+ )
+ else:
+ prefill_tokens = None
+
+ if top_n_tokens > 0:
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ Tokens(
+ [next_token_id_squeezed],
+ [next_token_logprob],
+ [next_token_text],
+ [next_token_id_squeezed.item() in self.all_special_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ # Update values
+ batch.next_token_choosers[i] = batch.next_token_choosers[
+ i
+ ].advance_grammar(next_token_id_squeezed.item())
+ batch.input_ids[i, 0] = next_token_id
+ batch.all_input_ids[i] = all_input_ids
+ batch.input_lengths[i] = new_input_length
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.max_input_length = max(batch.max_input_length, new_input_length)
+
+ # We finished all generations in the batch; there is no next batch
+ if stopped:
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, None, (forward_ns, decode_ns)
+
+ # Slice unused values from prefill
+ batch.input_ids = batch.input_ids[:, :1]
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch, (forward_ns, decode_ns)
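+
+
+# Illustrative sketch (this helper is not part of the upstream module): the
+# decode path in `forward` above pads the runtime batch size to one of the
+# graph buckets captured at warmup (4, 8, then multiples of 8) before looking
+# up a captured cuda graph. The standalone function below mirrors that rule so
+# the mapping is easy to check in isolation.
+def _cuda_graph_bucket(bs: int) -> int:
+    """Return the padded batch size used to look up a captured cuda graph."""
+    if bs == 3:
+        return 4
+    if 3 < bs <= 8:
+        return 8
+    if bs > 8:
+        # Round up to the next multiple of 8.
+        return (bs + 7) // 8 * 8
+    # Batch sizes 1 and 2 are used as-is.
+    return bs
+
+
+# Example: _cuda_graph_bucket(3) == 4, _cuda_graph_bucket(11) == 16 and
+# _cuda_graph_bucket(2) == 2.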
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
new file mode 100644
index 000000000..e034ed492
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -0,0 +1,308 @@
+import torch
+
+import numpy as np
+
+from typing import Iterable, Optional, Tuple, List, Dict
+from text_generation_server.pb.generate_pb2 import Request
+from io import BytesIO
+from PIL import Image
+from dataclasses import dataclass, field
+from opentelemetry import trace
+from transformers import (
+ PreTrainedTokenizerBase,
+)
+
+from text_generation_server.models.flash_vlm_causal_lm import (
+ FlashVlmCausalLMBatch,
+ FlashVlmCausalLM,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
+import habana_frameworks.torch as htorch
+
+tracer = trace.get_tracer(__name__)
+
+
+@dataclass
+class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
+    image_indices: List[int] = field(default_factory=list)
+ aspect_ratio_ids: Optional[torch.Tensor] = None
+ aspect_ratio_mask: Optional[torch.Tensor] = None
+ cross_attention_states: Optional[torch.Tensor] = None
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches):
+ batch = super().concatenate(batches)
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+
+ offset = 0
+ image_indices = []
+ attention_states = []
+ for b in batches:
+ if b.cross_attention_states is not None:
+ attention_states.append(b.cross_attention_states)
+ image_indices.extend([i + offset for i in b.image_indices])
+ offset += len(b.image_indices)
+ if len(attention_states) > 0:
+ assert len(image_indices) > 0
+ batch.cross_attention_states = torch.cat(attention_states, dim=0)
+ batch.image_indices = image_indices
+ else:
+ batch.cross_attention_states = None
+ batch.image_indices = []
+ return batch
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]):
+ assert self.image_indices is not None
+ batch = super().filter(request_ids)
+ assert self.image_indices is not None
+ indices = []
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ indices.append(idx)
+
+ offset = 0
+ new_image_indices = []
+ prev_i = None
+ for i in self.image_indices:
+ if i in indices:
+ new_image_indices.append(offset)
+ if i != prev_i:
+ offset += 1
+ prev_i = i
+
+ batch.image_indices = new_image_indices
+ if len(new_image_indices) > 0:
+ assert max(new_image_indices) < self.cross_attention_states.shape[0]
+ assert offset <= self.cross_attention_states.shape[0]
+ batch.cross_attention_states = self.cross_attention_states[
+ new_image_indices
+ ]
+ else:
+ batch.cross_attention_states = None
+ return batch
+
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[Request], tokenizer, processor, config
+ ):
+ image_inputs = []
+ texts = []
+ image_indices = []
+ batch_tokenized_inputs = []
+
+ for i, r in enumerate(requests):
+ # Each input is encoded into a list, where each element of this input list is either a string or a URL
+ curr_text = ""
+ curr_image = None
+ curr_i = None
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ curr_text += chunk.text
+ elif chunk_type == "image":
+ image = Image.open(BytesIO(chunk.image.data))
+ # TODO unsure about BOS
+ curr_text += "<|image|>"
+ image_input = processor.image_processor(image, return_tensors="pt")
+ curr_image = image_input
+ curr_i = i
+ # image_inputs.append(image_input)
+ # image_indices.append(i)
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+ texts.append(curr_text)
+ if curr_image is not None:
+ image_inputs.append(curr_image)
+ image_indices.append(curr_i)
+
+ input_ids = tokenizer(
+ curr_text,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=r.add_special_tokens,
+ )["input_ids"]
+ batch_tokenized_inputs.append(input_ids)
+ if image_inputs:
+ image_input = image_inputs[0]
+ new_image_inputs = {
+ "pixel_values": torch.cat(
+ [img["pixel_values"] for img in image_inputs], dim=0
+ ),
+ }
+ if "aspect_ratio_ids" in image_input:
+ new_image_inputs["aspect_ratio_ids"] = torch.cat(
+ [img["aspect_ratio_ids"] for img in image_inputs], dim=0
+ )
+ if "aspect_ratio_mask" in image_input:
+ new_image_inputs["aspect_ratio_mask"] = torch.cat(
+ [img["aspect_ratio_mask"] for img in image_inputs], dim=0
+ )
+ image_inputs = new_image_inputs
+ image_inputs["image_indices"] = image_indices
+ else:
+ image_inputs = None
+
+ if image_inputs is not None:
+ assert len(image_indices) == image_inputs["pixel_values"].shape[0]
+
+ return batch_tokenized_inputs, image_inputs
+
+ @classmethod
+ def from_pb_processor(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor,
+ config,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "FlashVlmCausalLMBatch":
+ batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+ pb.requests, tokenizer, processor, config
+ )
+ batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+ # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
+ batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
+ max=config.text_config.vocab_size - 1
+ )
+ if isinstance(batch.input_ids, list):
+ if len(batch) > 1:
+ input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
+ else:
+ input_ids = batch.input_ids[0]
+ batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
+ batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
+
+ if image_inputs is not None:
+ batch.pixel_values = image_inputs["pixel_values"].to(
+ device=device, dtype=dtype
+ )
+ batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
+ batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
+ device=device
+ )
+ batch.image_indices = image_inputs["image_indices"]
+ else:
+ batch.pixel_values = None
+ batch.aspect_ratio_ids = None
+ batch.aspect_ratio_mask = None
+ batch.image_indices = []
+ assert batch.image_indices is not None
+ return batch
+
+
+class FlashMllamaCausalLM(FlashVlmCausalLM):
+ def forward(
+ self,
+ batch: FlashMllamaCausalLMBatch,
+ adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ # Model Forward
+ if batch.speculative_ids is not None:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ speculative_ids = batch.speculative_ids
+
+ B, speculative_length = speculative_ids.shape
+ new_length = speculative_length + 1
+ new_input_ids = torch.cat(
+ [input_ids.unsqueeze(-1), speculative_ids], dim=1
+ ).reshape(-1)
+ arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+ arange_int = arange.to(dtype=torch.int32)
+ new_position_ids = (
+ position_ids.unsqueeze(-1).expand(B, new_length) + arange
+ ).view(-1)
+ slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+ input_lengths = (
+ input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+ ).view(-1)
+ cache_lengths_tensor = (
+ batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+ ).reshape(-1)
+
+            # Copy the block tables for all members
+ block_tables = (
+ block_tables.unsqueeze(1)
+ .expand(B, new_length, -1)
+ .reshape(B * new_length, -1)
+ .contiguous()
+ )
+ max_s = max_s + speculative_length
+
+ input_ids = new_input_ids
+ position_ids = new_position_ids
+ else:
+ input_ids = batch.input_ids
+ position_ids = batch.position_ids
+ cu_seqlen_prefill = batch.cu_seqlen_prefill
+ kv_cache = self.kv_cache
+ block_tables = batch.block_tables_tensor
+ slots = batch.slots[batch.slot_indices]
+ input_lengths = batch.input_lengths_tensor
+ cache_lengths_tensor = batch.cache_lengths_tensor
+ max_s = batch.max_current_length
+ lm_head_indices = batch.prefill_head_indices
+
+ if cu_seqlen_prefill is None and self.max_past() is not None:
+ # In decode, not prefill, we're actually overwriting the KV-cache
+ # in a circular buffer mode.
+ # This makes sure the max_s for the decode pass is correct.
+ max_s = min(self.max_past(), max_s)
+
+ seqlen = Seqlen(
+ input_lengths=input_lengths,
+ cache_lengths=cache_lengths_tensor,
+ cu_seqlen_q=cu_seqlen_prefill,
+ )
+
+ if batch.pixel_values is not None:
+ cross_attention_states = self.model.vision_forward(
+ pixel_values=batch.pixel_values,
+ aspect_ratio_ids=batch.aspect_ratio_ids,
+ aspect_ratio_mask=batch.aspect_ratio_mask,
+ )
+ batch.cross_attention_states = cross_attention_states
+
+ cross_attention_states = batch.cross_attention_states
+
+ kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ kwargs["bypass_hpu_graphs"] = False
+ if batch.prefill_cache_indices is not None:
+ slots_pad = torch.zeros_like(input_ids)
+ slots_pad[batch.prefill_cache_indices] = slots
+ slots = slots_pad
+ logits, speculative_logits = self.model.forward(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cu_seqlen_prefill=cu_seqlen_prefill,
+ kv_cache=kv_cache,
+ slots=slots,
+ seqlen=trim_seqlen_metadata(seqlen),
+ hpu_attention_meta=batch.hpu_attn_meta,
+ lm_head_indices=lm_head_indices,
+ cross_attention_states=cross_attention_states,
+ # TODO list
+ adapter_data=None,
+ image_indices=batch.image_indices[:],
+ **kwargs,
+ )
+ if batch.prefill_cache_indices is not None:
+ batch.prefill_cache_indices = None
+ if batch.pixel_values is not None:
+ batch.pixel_values = None
+ return logits, speculative_logits
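+
+
+# Illustrative sketch (hypothetical helper, not used by the class above): in
+# `batch_tokenized_inputs` every request is flattened into a single prompt
+# string, with each image chunk contributing the literal "<|image|>"
+# placeholder expected by the Mllama processor. The helper below shows that
+# assembly on plain (chunk_type, content) pairs instead of protobuf chunks.
+def _assemble_mllama_prompt(chunks: List[Tuple[str, str]]) -> str:
+    """Concatenate text chunks and replace every image chunk with <|image|>."""
+    text = ""
+    for chunk_type, content in chunks:
+        if chunk_type == "text":
+            text += content
+        elif chunk_type == "image":
+            text += "<|image|>"
+        else:
+            raise RuntimeError(f"Invalid chunk type {chunk_type}")
+    return text
+
+
+# Example:
+# _assemble_mllama_prompt([("text", "Describe this: "), ("image", "<raw bytes>")])
+# == "Describe this: <|image|>"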
diff --git a/backends/gaudi/server/text_generation_server/models/model.py b/backends/gaudi/server/text_generation_server/models/model.py
new file mode 100644
index 000000000..66c69bc1f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/model.py
@@ -0,0 +1,142 @@
+import inspect
+import torch
+
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Optional, TypeVar, Type, Dict
+from collections import defaultdict
+from transformers import PreTrainedTokenizerBase
+
+from text_generation_server.models.types import Batch, Generation
+from text_generation_server.models.globals import BLOCK_SIZE
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.pb.generate_pb2 import InfoResponse
+from text_generation_server.adapters.weights import LayerAdapterWeights
+from text_generation_server.pb import generate_pb2
+
+BASE_MODEL_ADAPTER_ID = "__base_model__"
+
+
+B = TypeVar("B", bound=Batch)
+
+
+class Model(ABC):
+ def __init__(
+ self,
+ model_id: str,
+ model: torch.nn.Module,
+ tokenizer: PreTrainedTokenizerBase,
+ requires_padding: bool,
+ dtype: torch.dtype,
+ device: torch.device,
+ rank: int = 0,
+ world_size: int = 1,
+ sliding_window: Optional[int] = None,
+ speculate: Optional[int] = None,
+ adapter_id: str = BASE_MODEL_ADAPTER_ID,
+ support_chunking: bool = False,
+ ):
+ self.model_id = model_id
+ self.model = model.eval()
+ self.tokenizer = tokenizer
+
+ # all_special_ids is not set correctly if the rust tokenizer is unpacked
+ # TODO report this to transformers.
+ other_special_ids = {
+ id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+ }
+ self.all_special_ids = set(tokenizer.all_special_ids)
+ self.all_special_ids.update(other_special_ids)
+ self.requires_padding = requires_padding
+ self.dtype = dtype
+ self.device = device
+ self.rank = rank
+ self.world_size = world_size
+ self.sliding_window = sliding_window if sliding_window != -1 else None
+
+ self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
+ LayerAdapterWeights
+ )
+ self.loaded_adapters = set()
+ self.static_adapter_id = adapter_id
+
+ if speculate is None:
+ speculate = get_speculate()
+ self.speculate = speculate
+
+ self.has_position_ids = (
+ inspect.signature(model.forward).parameters.get("position_ids", None)
+ is not None
+ )
+
+ self.check_initialized()
+
+ @property
+ def info(self) -> InfoResponse:
+ if self.requires_padding and self.sliding_window is not None:
+ raise NotImplementedError("sliding_window is not implemented with padding")
+
+ return InfoResponse(
+ requires_padding=self.requires_padding,
+ dtype=str(self.dtype),
+ device_type=self.device.type,
+ window_size=self.sliding_window,
+ speculate=self.speculate,
+ block_size=BLOCK_SIZE,
+ )
+
+ @property
+ @abstractmethod
+ def batch_type(self) -> Type[B]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def generate_token(
+ self, batch: B
+ ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
+ raise NotImplementedError
+
+ def warmup(
+ self, batch: generate_pb2.WarmupRequest
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+ self.generate_token(batch)
+ return None, None, None
+
+ def decode_token(
+ self,
+ all_input_ids: List[int],
+ prefix_offset: int = 0,
+ read_offset: int = 0,
+ skip_special_tokens: bool = False,
+ ) -> Tuple[str, int, int]:
+ """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+
+ # The prefix text is necessary only to defeat cleanup algorithms in the decode
+ # which decide to add a space or not depending on the surrounding ids.
+ prefix_text = self.tokenizer.decode(
+ all_input_ids[prefix_offset:read_offset],
+ skip_special_tokens=skip_special_tokens,
+ )
+
+ new_text = self.tokenizer.decode(
+ all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
+ )
+
+ if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+ # utf-8 char at the end means it's a potential unfinished byte sequence
+ # from byte fallback tokenization.
+ # If it's in the middle, it's probably a real invalid id generated
+ # by the model
+ new_text = new_text[len(prefix_text) :]
+ return new_text, read_offset, len(all_input_ids)
+ else:
+ return "", prefix_offset, read_offset
+
+ def check_initialized(self):
+ uninitialized_parameters = []
+ for n, p in self.model.named_parameters():
+ if p.data.device == torch.device("meta"):
+ uninitialized_parameters.append(n)
+ if uninitialized_parameters:
+ raise RuntimeError(
+ f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
+ )
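+
+
+# Illustrative sketch (assumption: this helper is not part of the upstream
+# class): `decode_token` streams text by re-decoding a small window of ids.
+# The slice `prefix_offset:read_offset` is decoded only so the tokenizer
+# applies the same spacing/cleanup it would apply to the running text; whatever
+# is decoded beyond that prefix is the newly produced text. The loop below
+# shows how a caller could stream a whole sequence with any Model instance.
+def _stream_text(model: "Model", token_ids: List[int]) -> str:
+    """Incrementally decode `token_ids`, mimicking what generate_token does."""
+    prefix_offset, read_offset = 0, 0
+    streamed = ""
+    for position in range(1, len(token_ids) + 1):
+        new_text, prefix_offset, read_offset = model.decode_token(
+            token_ids[:position], prefix_offset, read_offset
+        )
+        streamed += new_text
+    return streamed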
diff --git a/backends/gaudi/server/text_generation_server/models/pali_gemma.py b/backends/gaudi/server/text_generation_server/models/pali_gemma.py
new file mode 100644
index 000000000..e91aaed99
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/pali_gemma.py
@@ -0,0 +1,71 @@
+from io import BytesIO
+from PIL import Image
+import torch
+import torch.distributed
+from opentelemetry import trace
+from typing import Iterable
+from text_generation_server.models.flash_vlm_causal_lm import (
+ FlashVlmCausalLMBatch,
+ image_text_replacement,
+)
+
+from text_generation_server.pb.generate_pb2 import Request
+
+tracer = trace.get_tracer(__name__)
+
+
+class PaliGemmaBatch(FlashVlmCausalLMBatch):
+ @classmethod
+ def batch_tokenized_inputs(
+ cls, requests: Iterable[Request], tokenizer, processor, config
+ ):
+ batch_inputs = []
+ image_inputs = []
+ max_truncation = 0
+ for r in requests:
+ full_text = ""
+ image_id = 0
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ full_text += "" + chunk.text + "\n"
+ elif chunk_type == "image":
+ image = Image.open(BytesIO(chunk.image.data))
+                # TODO: should do_convert_RGB be enabled by default?
+ image = image.convert("RGB")
+ image_input = processor.image_processor(image, return_tensors="pt")
+ full_text += image_text_replacement(
+ processor, image_input, config, image_id
+ )
+ image_inputs.append(image_input)
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+ batch_inputs.append(full_text)
+ max_truncation = max(max_truncation, r.truncate)
+
+ batch_tokenized_inputs = tokenizer(
+ batch_inputs,
+ truncation=True,
+ max_length=max_truncation,
+ add_special_tokens=False,
+ )["input_ids"]
+ if image_inputs:
+ image_input = image_inputs[0]
+ new_image_inputs = {
+ "pixel_values": torch.cat(
+ [img["pixel_values"] for img in image_inputs], dim=0
+ ),
+ }
+ if "pixel_attention_mask" in image_input:
+ new_image_inputs["pixel_attention_mask"] = torch.cat(
+ [img["pixel_attention_mask"] for img in image_inputs], dim=0
+ )
+ if "image_sizes" in image_input:
+ new_image_inputs["image_sizes"] = torch.cat(
+ [img["image_sizes"] for img in image_inputs], dim=0
+ )
+ image_inputs = new_image_inputs
+ else:
+ image_inputs = None
+ return batch_tokenized_inputs, image_inputs
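+
+
+# Illustrative sketch (hypothetical helper): `batch_tokenized_inputs` above
+# builds one prompt per request by terminating every text chunk with "\n" and
+# splicing in the placeholder returned by `image_text_replacement` for image
+# chunks, then tokenizes without special tokens. The helper below reproduces
+# only the text-chunk part of that assembly.
+def _assemble_paligemma_text(text_chunks: Iterable[str]) -> str:
+    """Join text chunks the same way the loop above does (newline-terminated)."""
+    full_text = ""
+    for chunk in text_chunks:
+        full_text += chunk + "\n"
+    return full_text
+
+
+# Example: _assemble_paligemma_text(["caption en"]) == "caption en\n"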
diff --git a/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py
new file mode 100644
index 000000000..0ee6ed167
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/seq2seq_lm.py
@@ -0,0 +1,920 @@
+import torch
+import torch.distributed
+import time
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import (
+ AutoTokenizer,
+ AutoModelForSeq2SeqLM,
+ PreTrainedTokenizerBase,
+ AutoConfig,
+)
+from typing import Optional, Tuple, List, Type, Dict
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.quantization import get_loader
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+ GeneratedText,
+ Batch,
+ Generation,
+ Tokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+
+tracer = trace.get_tracer(__name__)
+
+
+@dataclass
+class Seq2SeqLMBatch(Batch):
+ batch_id: int
+ requests: List[generate_pb2.Request]
+ requests_idx_mapping: Dict[int, int]
+
+ # Encoder values
+ input_ids: Optional[torch.Tensor]
+ attention_mask: torch.Tensor
+
+ # Decoder values
+ decoder_input_ids: torch.Tensor
+ decoder_attention_mask: Optional[torch.Tensor]
+ encoder_last_hidden_state: Optional[torch.Tensor]
+
+ # All tokens
+ all_decoder_input_ids: List[torch.Tensor]
+
+ # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
+ past_key_values: Optional[List[Tuple]]
+
+ # Lengths of all generations present in the batch
+ input_lengths: List[int]
+ decoder_input_lengths: List[int]
+ prefix_offsets: List[int]
+ read_offsets: List[int]
+
+ # Generation helpers
+ next_token_choosers: List[NextTokenChooser]
+ stopping_criterias: List[StoppingCriteria]
+ top_n_tokens: List[int]
+ top_n_tokens_tensor: torch.Tensor
+
+ # Metadata used for padding
+ max_input_length: int
+ max_decoder_input_length: int
+ padding_right_offset: int
+
+ # Maximum number of tokens this batch will grow to
+ max_tokens: int
+
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ """Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
+ return generate_pb2.CachedBatch(
+ id=self.batch_id,
+ request_ids=[r.id for r in self.requests],
+ size=len(self),
+ max_tokens=self.max_tokens,
+ )
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "Seq2SeqLMBatch":
+ """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
+ inputs = []
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+ decoder_input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ requests_idx_mapping = {}
+
+ # Parse batch
+ max_truncation = 0
+ padding_right_offset = 0
+ max_decode_tokens = 0
+ for i, r in enumerate(pb.requests):
+ inputs.append(concat_text_chunks(r.input_chunks.chunks))
+ requests_idx_mapping[r.id] = i
+ decoder_input_lengths.append(1)
+ next_token_choosers.append(
+ NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+ )
+ stopping_criteria = StoppingCriteria.from_pb(
+ r.stopping_parameters, tokenizer
+ )
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(r.top_n_tokens)
+ max_truncation = max(max_truncation, r.truncate)
+ max_decode_tokens += stopping_criteria.max_new_tokens
+ padding_right_offset = max(
+ padding_right_offset, stopping_criteria.max_new_tokens
+ )
+
+ # Tokenize batch
+ tokenized_inputs = tokenizer(
+ inputs,
+ return_tensors="pt",
+ padding=True,
+ return_token_type_ids=False,
+ truncation=True,
+ max_length=max_truncation,
+ ).to(device)
+
+ input_lengths = tokenized_inputs["attention_mask"].sum(1)
+ max_input_length = input_lengths.max()
+
+ # Decoder sequence only contains the bos_token
+ decoder_input_ids = (
+ torch.tensor(tokenizer.bos_token_id, device=device)
+ .repeat(len(pb.requests))
+ .view(-1, 1)
+ )
+ for _ in pb.requests:
+ prefix_offsets.append(0)
+ read_offsets.append(1)
+ all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+
+ max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
+
+ return cls(
+ batch_id=pb.id,
+ requests=pb.requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=tokenized_inputs["input_ids"],
+ attention_mask=tokenized_inputs["attention_mask"],
+ decoder_input_ids=decoder_input_ids,
+ all_decoder_input_ids=list(all_decoder_input_ids),
+ decoder_attention_mask=None,
+ encoder_last_hidden_state=None,
+ past_key_values=None,
+ input_lengths=input_lengths.tolist(),
+ decoder_input_lengths=decoder_input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length.item(),
+ max_decoder_input_length=1,
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
+
+ @tracer.start_as_current_span("filter")
+ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ if len(request_ids) == len(self):
+ return self
+
+ keep_indices = []
+
+ # New values after filtering
+ requests_idx_mapping = {}
+ requests = []
+ input_lengths = []
+ decoder_input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+
+ all_decoder_input_ids = []
+
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+
+ max_input_length = 0
+ max_decoder_input_length = 0
+ padding_right_offset = 0
+
+ total_remaining_decode_tokens = 0
+
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ requests_idx_mapping[request_id] = i
+ keep_indices.append(idx)
+
+ requests.append(self.requests[idx])
+ prefix_offsets.append(self.prefix_offsets[idx])
+ read_offsets.append(self.read_offsets[idx])
+
+ all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
+
+ request_input_length = self.input_lengths[idx]
+ input_lengths.append(request_input_length)
+ max_input_length = max(max_input_length, request_input_length)
+
+ request_decoder_input_length = self.decoder_input_lengths[idx]
+ decoder_input_lengths.append(request_decoder_input_length)
+ max_decoder_input_length = max(
+ max_decoder_input_length, request_decoder_input_length
+ )
+
+ next_token_choosers.append(self.next_token_choosers[idx])
+ stopping_criteria = self.stopping_criterias[idx]
+ stopping_criterias.append(stopping_criteria)
+ top_n_tokens.append(self.top_n_tokens[idx])
+ remaining_decode_tokens = (
+ stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+ )
+ total_remaining_decode_tokens += remaining_decode_tokens
+ padding_right_offset = max(padding_right_offset, remaining_decode_tokens)
+
+ # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
+ self.decoder_input_ids = self.decoder_input_ids[keep_indices]
+ self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
+ if self.decoder_attention_mask is not None:
+ self.decoder_attention_mask = self.decoder_attention_mask[
+ keep_indices,
+ -(self.padding_right_offset + max_decoder_input_length) : (
+ self.decoder_attention_mask.shape[1] - self.padding_right_offset
+ )
+ + padding_right_offset,
+ ]
+
+ self.encoder_last_hidden_state = self.encoder_last_hidden_state[
+ keep_indices, -max_input_length:
+ ]
+
+ # Ensure that past_key_values tensors can be updated in-place
+ if type(self.past_key_values[0]) is tuple:
+ self.past_key_values = [
+ [t for t in layer] for layer in self.past_key_values
+ ]
+
+ decoder_past_seq_len = max_decoder_input_length - 1
+ for layer in self.past_key_values:
+ layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
+ layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
+ layer[2] = layer[2][keep_indices, :, -max_input_length:]
+ layer[3] = layer[3][keep_indices, :, -max_input_length:]
+
+ top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
+ max_tokens = (
+ len(request_ids) * (max_input_length + max_decoder_input_length)
+            + total_remaining_decode_tokens
+ )
+
+ self.requests = requests
+ self.requests_idx_mapping = requests_idx_mapping
+ self.input_ids = None
+ self.all_decoder_input_ids = all_decoder_input_ids
+ self.input_lengths = input_lengths
+ self.decoder_input_lengths = decoder_input_lengths
+ self.prefix_offsets = prefix_offsets
+ self.read_offsets = read_offsets
+ self.next_token_choosers = next_token_choosers
+ self.stopping_criterias = stopping_criterias
+ self.top_n_tokens = top_n_tokens
+ self.top_n_tokens_tensor = top_n_tokens_tensor
+ self.max_input_length = max_input_length
+ self.max_decoder_input_length = max_decoder_input_length
+ self.padding_right_offset = padding_right_offset
+ self.max_tokens = max_tokens
+
+ return self
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+ """Concatenate multiple batches together by padding internal torch tensors"""
+
+ # Used for padding
+ total_batch_size = 0
+ max_input_length = 0
+ max_decoder_input_length = 0
+ padding_right_offset = 0
+ for batch in batches:
+ total_batch_size += len(batch)
+ max_input_length = max(max_input_length, batch.max_input_length)
+ max_decoder_input_length = max(
+ max_decoder_input_length, batch.max_decoder_input_length
+ )
+ padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
+
+ # Batch attributes
+ requests = []
+ requests_idx_mapping = {}
+ all_decoder_input_ids = []
+ input_lengths = []
+ decoder_input_lengths = []
+ prefix_offsets = []
+ read_offsets = []
+ next_token_choosers = []
+ stopping_criterias = []
+ top_n_tokens = []
+ max_tokens = 0
+
+ # Batch tensors
+ attention_mask = None
+ decoder_input_ids = None
+ decoder_attention_mask = None
+ encoder_last_hidden_state = None
+ top_n_tokens_tensor = None
+ past_key_values = []
+
+ # Used for slicing correctly inside the tensors
+ # Equivalent to a cumsum on batch sizes
+ start_index = 0
+
+ for i, batch in enumerate(batches):
+ # Extend all list attributes
+ requests.extend(batch.requests)
+ all_decoder_input_ids.extend(batch.all_decoder_input_ids)
+ input_lengths.extend(batch.input_lengths)
+ decoder_input_lengths.extend(batch.decoder_input_lengths)
+ prefix_offsets.extend(batch.prefix_offsets)
+ read_offsets.extend(batch.read_offsets)
+ next_token_choosers.extend(batch.next_token_choosers)
+ stopping_criterias.extend(batch.stopping_criterias)
+ top_n_tokens.extend(batch.top_n_tokens)
+
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ # We need to offset the mapping for each batch by the cumulative batch size
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + start_index
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+
+ # We only concatenate batches that did at least one step
+ if batch.encoder_last_hidden_state is None:
+ raise ValueError("Batch encoder_last_hidden_state cannot be None")
+
+ # Create padded tensor
+ if attention_mask is None:
+ attention_mask = batch.attention_mask.new_zeros(
+ (total_batch_size, max_input_length),
+ )
+ # Copy to correct indices
+ attention_mask[start_index:end_index, -batch.max_input_length :] = (
+ batch.attention_mask[:, -batch.max_input_length :]
+ )
+
+ # Create padded tensor
+ if decoder_input_ids is None:
+ decoder_input_ids = batch.decoder_input_ids.new_zeros(
+ (total_batch_size, 1),
+ )
+ # Copy to correct indices
+ decoder_input_ids[start_index:end_index] = batch.decoder_input_ids
+
+ # Create padded tensor
+ if decoder_attention_mask is None:
+ # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+ decoder_attention_mask = batch.attention_mask.new_zeros(
+ (total_batch_size, max_decoder_input_length + padding_right_offset),
+ )
+ # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
+ # this batch. All generations are of length `batch.max_decoder_input_length`.
+ left_offset = max_decoder_input_length - batch.max_decoder_input_length
+ if batch.decoder_attention_mask is None:
+ decoder_attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ ] = 1
+ # If it exists, we need to index
+ else:
+ batch_left_offset = (
+ batch.decoder_attention_mask.shape[1]
+ - batch.max_decoder_input_length
+ - batch.padding_right_offset
+ )
+ decoder_attention_mask[
+ start_index:end_index,
+ left_offset:-padding_right_offset,
+ ] = batch.decoder_attention_mask[
+ :,
+ batch_left_offset : -batch.padding_right_offset,
+ ]
+
+ # Create padded tensor
+ if encoder_last_hidden_state is None:
+ encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
+ (
+ total_batch_size,
+ max_input_length,
+ batch.encoder_last_hidden_state.shape[-1],
+ ),
+ )
+
+ if top_n_tokens_tensor is None:
+ top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+ total_batch_size,
+ )
+ top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
+ # Copy to correct indices
+ encoder_last_hidden_state[
+ start_index:end_index, -batch.max_input_length :, :
+ ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
+ batch.encoder_last_hidden_state = None
+
+ # Ensure that we can update tensors in-place
+ if isinstance(batch.past_key_values[0], tuple):
+ batch.past_key_values = [
+ [t for t in layer] for layer in batch.past_key_values
+ ]
+
+ # Add eventual padding tokens that were added while concatenating
+ max_tokens += batch.max_tokens + (
+ max_input_length
+ - batch.max_input_length
+ + max_decoder_input_length
+ - batch.max_decoder_input_length
+ ) * len(batch)
+
+ start_index = end_index
+
+ # Determine shapes for new past kv tensors
+ first_past_kvs = batches[0].past_key_values
+ _, num_heads, _, head_dim = first_past_kvs[0][0].shape
+
+ padded_dec_t_shape = (
+ total_batch_size,
+ num_heads,
+ (max_decoder_input_length - 1),
+ head_dim,
+ )
+
+ padded_enc_t_shape = (
+ total_batch_size,
+ num_heads,
+ max_input_length,
+ head_dim,
+ )
+
+ # Iterate over attention layers
+ for j in range(len(first_past_kvs)):
+ past_key_values.append([])
+
+ # Decoder past
+ for k in range(0, 2):
+ # Initialize tensors
+ padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
+ past_key_values[j].append(padded_past_values)
+
+ start_index = 0
+ for batch in batches:
+ t = batch.past_key_values[j][k]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][k] = None
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the past keys and values to remove the padding from previous batches
+ past_seq_len = batch.max_decoder_input_length - 1
+ padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
+ :, :, -past_seq_len:, :
+ ]
+ del t
+
+ start_index = end_index
+
+ # Encoder past
+ for k in range(2, 4):
+ # Initialize tensors
+ padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
+ past_key_values[j].append(padded_past_values)
+
+ start_index = 0
+ for batch in batches:
+ t = batch.past_key_values[j][k]
+ # Clear reference to the original tensor
+ batch.past_key_values[j][k] = None
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the past keys and values to remove the padding from previous batches
+ padded_past_values[
+ start_index:end_index, :, -batch.max_input_length :, :
+ ] = t[:, :, -batch.max_input_length :, :]
+ del t
+
+ start_index = end_index
+
+ return cls(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=None,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ all_decoder_input_ids=all_decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_last_hidden_state=encoder_last_hidden_state,
+ past_key_values=past_key_values,
+ input_lengths=input_lengths,
+ decoder_input_lengths=decoder_input_lengths,
+ prefix_offsets=prefix_offsets,
+ read_offsets=read_offsets,
+ next_token_choosers=next_token_choosers,
+ stopping_criterias=stopping_criterias,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ max_input_length=max_input_length,
+ max_decoder_input_length=max_decoder_input_length,
+ padding_right_offset=padding_right_offset,
+ max_tokens=max_tokens,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+
+class Seq2SeqLM(Model):
+ def __init__(
+ self,
+ model_id: str,
+ model_class,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ default_dtype=torch.float16,
+ trust_remote_code: bool = False,
+ config_class=AutoConfig,
+ tokenizer_class=AutoTokenizer,
+ aliases=None,
+ ):
+ self.quantize = quantize
+ self.process_group, rank, world_size = initialize_torch_distributed()
+
+ device = torch.device("hpu")
+ dtype = torch.bfloat16 if dtype is None else dtype
+
+ config = config_class.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ )
+ config.quantize = quantize
+ config.speculator = speculator
+
+ tokenizer = tokenizer_class.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ tokenizer.bos_token_id = config.decoder_start_token_id
+
+ weights_loader = get_loader(
+ quantize=quantize, model_id=model_id, revision=revision
+ )
+ torch.distributed.barrier(group=self.process_group)
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(
+ filenames,
+ device=device,
+ dtype=dtype,
+ process_group=self.process_group,
+ aliases=aliases,
+ weights_loader=weights_loader,
+ )
+ if config.quantize in ["awq", "gptq"]:
+ weights._set_gptq_params(model_id, revision)
+
+ model = model_class(config, weights)
+
+ torch.distributed.barrier(group=self.process_group)
+ super().__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ )
+
+ @classmethod
+ def fallback(
+ cls,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ speculator: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ ):
+ if speculator:
+ raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ dtype = torch.float16 if dtype is None else dtype
+ else:
+ if quantize:
+ raise ValueError("quantization is not available on CPU")
+
+ device = torch.device("cpu")
+ dtype = torch.float32 if dtype is None else dtype
+
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ device_map=(
+ "auto"
+ if torch.cuda.is_available() and torch.cuda.device_count() > 1
+ else None
+ ),
+ load_in_8bit=quantize == "bitsandbytes",
+ trust_remote_code=trust_remote_code,
+ )
+ if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+ model = model.cuda()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ tokenizer.bos_token_id = model.config.decoder_start_token_id
+
+ self = cls.__new__(
+ cls,
+ )
+ super().__init__(
+ self,
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ )
+ self.quantize = quantize
+ return self
+
+ @property
+ def batch_type(self) -> Type[Seq2SeqLMBatch]:
+ return Seq2SeqLMBatch
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+        decoder_attention_mask: Optional[torch.Tensor],
+        encoder_last_hidden_state: Optional[torch.Tensor],
+        past_key_values: Optional[List[Tuple]] = None,
+ ) -> Tuple[
+ torch.Tensor,
+ Optional[torch.Tensor],
+ torch.Tensor,
+ List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
+ ]:
+ # Model Forward
+ outputs = self.model.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_last_hidden_state,
+ past_key_values=past_key_values,
+ use_cache=True,
+ )
+ if isinstance(outputs, tuple):
+ # Our custom models
+ outputs, speculative_logits = outputs
+ else:
+ # Generic transformers models
+ speculative_logits = None
+ return (
+ outputs.logits,
+ speculative_logits,
+ outputs.encoder_last_hidden_state,
+ outputs.past_key_values,
+ )
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batch: Seq2SeqLMBatch
+ ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
+ start = time.time_ns()
+ if batch.decoder_attention_mask is not None:
+ # slice to the correct shape
+ decoder_attention_mask = batch.decoder_attention_mask[
+ :, : -batch.padding_right_offset
+ ]
+ else:
+ decoder_attention_mask = None
+
+ # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
+ # internally...
+ if batch.encoder_last_hidden_state is not None:
+ encoder_last_hidden_state = [batch.encoder_last_hidden_state]
+ else:
+ encoder_last_hidden_state = None
+
+ logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
+ batch.input_ids,
+ batch.attention_mask,
+ batch.decoder_input_ids,
+ decoder_attention_mask,
+ encoder_last_hidden_state,
+ batch.past_key_values,
+ )
+
+ # Speculation is not active for seq2seq
+ accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens,
+ batch.top_n_tokens_tensor,
+ torch.log_softmax(logits[:, -1], -1),
+ accepted_ids,
+ )
+
+ start_decode = time.time_ns()
+
+ # Finished requests
+ generations: List[Generation] = []
+ stopped = True
+
+ # Zipped iterator
+ iterator = zip(
+ batch.requests,
+ batch.input_lengths,
+ batch.prefix_offsets,
+ batch.read_offsets,
+ batch.decoder_input_lengths,
+ logits,
+ batch.next_token_choosers,
+ batch.stopping_criterias,
+ batch.all_decoder_input_ids,
+ batch.top_n_tokens,
+ batch_top_token_ids,
+ batch_top_token_logprobs,
+ )
+
+ # For each member of the batch
+ for i, (
+ request,
+ input_length,
+ prefix_offset,
+ read_offset,
+ decoder_input_length,
+ logits,
+ next_token_chooser,
+ stopping_criteria,
+ all_decoder_input_ids,
+ top_n_tokens,
+ top_token_ids,
+ top_token_logprobs,
+ ) in enumerate(iterator):
+ # Select next token
+ next_token_id, logprobs = next_token_chooser(
+ all_decoder_input_ids.view(1, -1), logits[-1:, :]
+ )
+
+ # Append next token to decoder tokens
+ all_decoder_input_ids = torch.cat(
+ [all_decoder_input_ids, next_token_id.squeeze(1)]
+ )
+ new_decoder_input_length = decoder_input_length + 1
+
+ # Generated token
+ next_token_logprob = logprobs[-1, next_token_id]
+ next_token_id_squeezed = next_token_id.squeeze()
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_decoder_input_ids, prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(next_token_id, next_token_text)
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Slice with decoder_input_length to remove padding
+ # Decode all tokens
+ output_text, _, _ = self.decode_token(
+ all_decoder_input_ids,
+ prefix_offset=len(all_decoder_input_ids)
+ - decoder_input_length
+ - 1,
+ read_offset=len(all_decoder_input_ids) - decoder_input_length,
+ skip_special_tokens=True,
+ )
+
+ # Get seed
+ if isinstance(next_token_chooser.choice, Sampling):
+ seed = next_token_chooser.choice.seed
+ else:
+ seed = None
+
+ generated_text = GeneratedText(
+ output_text, stopping_criteria.current_tokens, reason, seed
+ )
+ else:
+ generated_text = None
+
+ # Prefill
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ prefill_tokens = Tokens(
+ [self.tokenizer.bos_token_id],
+ [float("nan")],
+ [self.tokenizer.bos_token],
+ [False],
+ )
+ else:
+ prefill_tokens = None
+
+ if top_n_tokens > 0:
+ all_top_tokens = []
+ for top_token_ids, top_token_logprobs in zip(
+ top_token_ids, top_token_logprobs
+ ):
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids
+ for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ all_top_tokens.append(top_tokens)
+ top_tokens = all_top_tokens
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ Tokens(
+ [next_token_id_squeezed],
+ [next_token_logprob],
+ [next_token_text],
+ [next_token_id_squeezed.item() in self.all_special_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ # Update values
+ batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+ next_token_id_squeezed.item()
+ )
+ batch.decoder_input_ids[i] = next_token_id
+ batch.all_decoder_input_ids[i] = all_decoder_input_ids
+ batch.input_lengths[i] = input_length
+ batch.decoder_input_lengths[i] = new_decoder_input_length
+ batch.prefix_offsets[i] = prefix_offset
+ batch.read_offsets[i] = read_offset
+ batch.max_input_length = max(batch.max_input_length, input_length)
+ batch.max_decoder_input_length = max(
+ batch.max_decoder_input_length, new_decoder_input_length
+ )
+
+ # We finished all generations in the batch; there is no next batch
+ if stopped:
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, None, (forward_ns, decode_ns)
+
+ # We don't need input_ids after the prefill forward
+ batch.input_ids = None
+ batch.encoder_last_hidden_state = encoder_last_hidden_state
+ batch.past_key_values = past
+ # Update decoder_attention_mask as we added a new token to input_ids
+ if batch.decoder_attention_mask is not None:
+ batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
+ batch.padding_right_offset -= 1
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch, (forward_ns, decode_ns)
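+
+
+# Illustrative sketch (not part of the upstream module): every layer of
+# `past_key_values` holds four tensors (decoder key, decoder value, encoder
+# key, encoder value), and `concatenate` re-allocates them with padded shapes
+# before copying each batch in. The helper below just reproduces the shape
+# computation used above.
+def _padded_past_shapes(
+    total_batch_size: int,
+    num_heads: int,
+    head_dim: int,
+    max_input_length: int,
+    max_decoder_input_length: int,
+) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    """Return (decoder_shape, encoder_shape) used to pad past key/value tensors."""
+    decoder_shape = (
+        total_batch_size,
+        num_heads,
+        # The decoder past excludes the token currently being generated.
+        max_decoder_input_length - 1,
+        head_dim,
+    )
+    encoder_shape = (total_batch_size, num_heads, max_input_length, head_dim)
+    return decoder_shape, encoder_shape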
diff --git a/backends/gaudi/server/text_generation_server/models/starcoder.py b/backends/gaudi/server/text_generation_server/models/starcoder.py
new file mode 100644
index 000000000..6c6ca2cf9
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/starcoder.py
@@ -0,0 +1,47 @@
+import torch
+from dataclasses import dataclass
+from typing import List, Optional, Type
+
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+
+
+@dataclass
+class StarCoderCausalLMBatch(CausalLMBatch):
+ past_key_values: Optional[List[torch.Tensor]]
+
+ def detach_kv_cache(self):
+ past_keys = []
+ past_values = []
+ last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
+ for key_value in self.past_key_values:
+ past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
+ past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
+ del self.past_key_values
+
+ return past_keys, past_values
+
+ def attach_kv_cache(self, past_keys, past_values):
+ self.past_key_values = [
+ torch.cat((key, value), dim=-1)
+ for key, value in zip(past_keys, past_values)
+ ]
+
+
+class StarCoder(CausalLM):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+
+ super(StarCoder, self).__init__(
+ model_id=model_id,
+ revision=revision,
+ dtype=dtype,
+ )
+
+ @property
+ def batch_type(self) -> Type[CausalLMBatch]:
+ return StarCoderCausalLMBatch
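+
+
+# Illustrative note (hypothetical usage, not executed by the server): StarCoder
+# stores key and value fused in a single cache tensor, so `detach_kv_cache`
+# splits the last dimension in half and `attach_kv_cache` concatenates it back.
+# The round trip is lossless:
+#
+#   fused = torch.randn(2, 10, 256)          # [batch, seq, 2 * head_dim]
+#   half = fused.size(-1) // 2
+#   key, value = fused.split((half, half), dim=-1)
+#   assert torch.equal(torch.cat((key, value), dim=-1), fused)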
diff --git a/backends/gaudi/server/text_generation_server/models/types.py b/backends/gaudi/server/text_generation_server/models/types.py
new file mode 100644
index 000000000..d4e7cca75
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/types.py
@@ -0,0 +1,102 @@
+import torch
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional
+
+from transformers import PreTrainedTokenizerBase
+
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason
+
+
+class Batch(ABC):
+ @abstractmethod
+ def to_pb(self) -> generate_pb2.CachedBatch:
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> "Batch":
+ raise NotImplementedError
+
+ @abstractmethod
+ def filter(self, request_ids: List[int]) -> "Batch":
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def concatenate(cls, batches: List["Batch"]) -> "Batch":
+ raise NotImplementedError
+
+ @abstractmethod
+ def __len__(self):
+ raise NotImplementedError
+
+
+@dataclass
+class GeneratedText:
+ text: str
+ generated_tokens: int
+ finish_reason: FinishReason
+ seed: Optional[int]
+
+ def to_pb(self) -> generate_pb2.GeneratedText:
+ return generate_pb2.GeneratedText(
+ text=self.text,
+ generated_tokens=self.generated_tokens,
+ finish_reason=self.finish_reason,
+ seed=self.seed,
+ )
+
+
+@dataclass
+class Tokens:
+ token_ids: List[int]
+ logprobs: List[float]
+ texts: List[str]
+ is_special: List[bool]
+
+ def to_pb(self) -> generate_pb2.Tokens:
+ return generate_pb2.Tokens(
+ ids=self.token_ids,
+ logprobs=self.logprobs,
+ texts=self.texts,
+ is_special=self.is_special,
+ )
+
+ def __len__(self):
+ return len(self.token_ids)
+
+
+@dataclass
+class Generation:
+ request_id: int
+ prefill_tokens: Optional[Tokens]
+ tokens: Tokens
+ generated_text: Optional[GeneratedText]
+ # Optional for now, since it's not yet supported for every model.
+ top_tokens: Optional[List[Tokens]]
+
+ def to_pb(self) -> generate_pb2.Generation:
+ return generate_pb2.Generation(
+ request_id=self.request_id,
+ prefill_tokens=(
+ self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
+ ),
+ tokens=self.tokens.to_pb(),
+ generated_text=(
+ self.generated_text.to_pb() if self.generated_text is not None else None
+ ),
+ top_tokens=(
+ [top_tokens.to_pb() for top_tokens in self.top_tokens]
+ if self.top_tokens is not None
+ else None
+ ),
+ )
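+
+
+# Illustrative example (hypothetical values): a single decode step for one
+# request typically produces a `Generation` whose `tokens` field holds exactly
+# one entry per list, and `to_pb()` turns the whole tree into the protobuf
+# messages sent back to the router.
+#
+#   step = Generation(
+#       request_id=0,
+#       prefill_tokens=None,
+#       tokens=Tokens(token_ids=[42], logprobs=[-0.1], texts=[" hello"], is_special=[False]),
+#       generated_text=None,
+#       top_tokens=None,
+#   )
+#   pb = step.to_pb()  # generate_pb2.Generation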
diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
new file mode 100644
index 000000000..709437d93
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py
@@ -0,0 +1,1603 @@
+import re
+import torch
+import os
+import time
+import math
+from PIL import Image
+from io import BytesIO
+from opentelemetry import trace
+from loguru import logger
+from typing import Iterable, Optional, Tuple, List, Type, Dict
+import tempfile
+import copy
+from text_generation_server.models import Model
+from transformers import PreTrainedTokenizerBase
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import (
+ CausalLMBatch,
+ CausalLMRequest,
+ remove_kv_cache_from_output,
+)
+
+from transformers.models.llava_next.modeling_llava_next import (
+ get_anyres_image_grid_shape,
+)
+
+from transformers import AutoProcessor
+import text_generation_server.habana_quantization_env as hq_env
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+from text_generation_server.utils import (
+ HeterogeneousNextTokenChooser,
+ make_tokenizer_optional,
+ is_tokenizer_transparent,
+ pad_next_token_chooser_parameters,
+)
+import habana_frameworks.torch as htorch
+from optimum.habana.utils import HabanaProfile
+from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
+from optimum.habana.utils import get_hpu_memory_stats
+from optimum.habana.checkpoint_utils import get_ds_injection_policy
+
+from transformers import (
+ AutoTokenizer,
+ AutoConfig,
+)
+from optimum.habana.checkpoint_utils import (
+ get_repo_root,
+ model_on_meta,
+ write_checkpoints_json,
+)
+
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.models.types import (
+ Tokens,
+ Generation,
+ GeneratedText,
+)
+from text_generation_server.utils.debug import dbg_trace
+
+tracer = trace.get_tracer(__name__)
+
+IDEFICS2_FAKE_TOKEN = ""
+IDEFICS2_IMAGE_TOKEN = ""
+
+
+IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
+BASE_IMAGE_TOKENS = int(os.environ.get("BASE_IMAGE_TOKENS", 2048))
+MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
+CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
+
+
+PREFILL_WARMUP_BATCH_SIZE_LIST = []
+PREFILL_WARMUP_SEQLEN_LIST = []
+DECODE_WARMUP_BATCH_SIZE_LIST = []
+CROSS_ATTENTION_LAYERS = []
+
+
+def round_up(warmup_list: list, num):
+ i = 0
+ for i in warmup_list:
+ if num <= i:
+ break
+ return i if i > 0 else num
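+
+
+# Illustrative usage (assuming an ascending warmup list, which is how the
+# warmup lists above are built): `round_up` snaps a runtime value to the first
+# warmup bucket that can hold it, and falls back to the value itself when the
+# list is empty.
+#
+#   round_up([1, 2, 4, 8], 3) == 4
+#   round_up([1, 2, 4, 8], 9) == 8   # larger than every bucket: last bucket is returned
+#   round_up([], 5) == 5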
+
+
+def split(string) -> List[Dict[str, str]]:
+ parts = []
+ cursor = 0
+ for pattern in IMAGES.finditer(string):
+ start = pattern.start()
+ if start != cursor:
+ parts.append({"type": "text", "content": string[cursor:start]})
+
+ parts.append({"type": "image", "content": pattern.group(1)})
+ cursor = pattern.end()
+
+ if cursor != len(string):
+ parts.append({"type": "text", "content": string[cursor:]})
+
+ return parts
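+
+
+# Illustrative usage (hypothetical input): `split` cuts a prompt on markdown
+# image syntax and keeps the URL of each image:
+#
+#   split("A photo: ![cat](https://example.com/cat.png) nice")
+#   # [{"type": "text", "content": "A photo: "},
+#   #  {"type": "image", "content": "https://example.com/cat.png"},
+#   #  {"type": "text", "content": " nice"}]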
+
+
+def image_text_replacement(config) -> str:
+ if config.model_type == "idefics2":
+ image_seq_len = 64
+ image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
+ return image_str
+ elif config.model_type == "llava_next":
+ return ""
+ elif config.model_type == "paligemma":
+ return ""
+ elif config.model_type == "mllama":
+ return "<|image|>"
+ else:
+ raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
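+
+
+# Illustrative note: with the constants above, an Idefics2 prompt expands each
+# image into "<fake_token_around_image>" + "<image>" * 64 +
+# "<fake_token_around_image>", i.e. a fixed 64-token image slot bracketed by
+# the fake token, while the other model types use a single literal placeholder.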
+
+
+def image_text_replacement_fixup(config, text: str) -> str:
+ if config.model_type == "idefics2":
+ return text.replace(
+ f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+ )
+ return text
+
+
+def get_unpadded_features(
+ original_height: int,
+ original_width: int,
+ npatches: int,
+ num_patch_height: int,
+ num_patch_width: int,
+) -> Tuple[int, int]:
+ current_height = npatches * num_patch_height
+ current_width = npatches * num_patch_width
+
+ aspect_ratio: float = original_width / original_height
+ current_aspect_ratio: float = current_width / current_height
+
+ if aspect_ratio > current_aspect_ratio:
+ new_height = (original_height * current_width) // original_width
+ padding = (current_height - new_height) // 2
+ current_height = current_height - (2 * padding)
+ else:
+ new_width = (original_width * current_height) // original_height
+ padding = (current_width - new_width) // 2
+ current_width = current_width - (2 * padding)
+
+ unpadded_features = current_height * current_width
+ newline_features = current_height
+ return (unpadded_features, newline_features)
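+
+
+# Worked example (illustrative numbers): a 640x480 image with npatches=24 and
+# a 2x2 patch grid gives a 48x48 feature grid. The image is wider than the
+# grid (4:3 vs 1:1), so the height shrinks to 36 rows once the letterbox
+# padding is removed:
+#
+#   get_unpadded_features(480, 640, 24, 2, 2) == (1728, 36)
+#   # 36 rows * 48 columns = 1728 unpadded features, plus 36 newline features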
+
+
+def get_number_of_features(height: int, width: int, config) -> int:
+ # From config
+ # Hardcoded for CLIP for now
+ # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ image_grid_pinpoints = config.image_grid_pinpoints
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+
+ assert image_size % patch_size == 0
+
+ npatches = image_size // patch_size
+
+ # Dimensions are intentionally swapped to be bug-compatible with
+ # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ [height, width],
+ image_grid_pinpoints,
+ image_size,
+ )
+
+ unpadded_features, newline_features = get_unpadded_features(
+ height, width, npatches, num_patch_height, num_patch_width
+ )
+ # The base patch covers the entire image
+ base_features = npatches**2
+ return unpadded_features + newline_features + base_features
+
+
+class VlmCausalLMBatch(CausalLMBatch):
+ pixel_values: Optional[List[torch.Tensor]]
+ pixel_attention_mask: Optional[List[torch.Tensor]]
+ image_sizes: Optional[List[Tuple[int, int]]]
+ aspect_ratio_ids: Optional[torch.Tensor] = None
+ aspect_ratio_mask: Optional[torch.Tensor] = None
+ cross_attention_mask: Optional[torch.Tensor] = None
+ prefilling: bool = True
+ token_idx: torch.Tensor = None
+
+ def __init__(
+ self,
+ batch_id,
+ requests,
+ input_ids,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ merged_kv_cache,
+ next_token_chooser,
+ top_n_tokens,
+ top_n_tokens_tensor,
+ input_length,
+ pixel_values: Optional[List[torch.Tensor]] = None,
+ pixel_attention_mask: Optional[List[torch.Tensor]] = None,
+ image_sizes: Optional[List[Tuple[int, int]]] = None,
+ aspect_ratio_ids: Optional[torch.Tensor] = None,
+ aspect_ratio_mask: Optional[torch.Tensor] = None,
+ cross_attention_mask: Optional[torch.Tensor] = None,
+ prefilling: Optional[bool] = True,
+ ):
+ super().__init__(
+ batch_id=batch_id,
+ requests=requests,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ merged_kv_cache=merged_kv_cache,
+ next_token_chooser=next_token_chooser,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ input_length=input_length,
+ )
+
+ self.pixel_values = pixel_values
+ self.pixel_attention_mask = pixel_attention_mask
+ self.image_sizes = image_sizes
+ self.aspect_ratio_ids = aspect_ratio_ids
+ self.aspect_ratio_mask = aspect_ratio_mask
+ self.cross_attention_mask = cross_attention_mask
+ self.prefilling = prefilling
+
+ @property
+ def token_idx(self):
+ if self.prefilling:
+ # no right padding for prefill
+ token_idx_scalar = self.attention_mask.shape[-1] - 1
+ return torch.tensor(token_idx_scalar).to(self.attention_mask.device)
+ else:
+ token_idx_scalar = self.attention_mask.shape[-1] - self.right_padding
+ return torch.tensor(token_idx_scalar).to(self.attention_mask.device)
+
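+    # Right-pad input_ids, attention_mask, the cross-attention mask and the self-attention
+    # KV cache up to MAX_TOTAL_TOKENS so that decode steps keep a static shape on HPU.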
+ def padding_process(self, pad_id: int):
+ # self.input_ids = torch.index_select(self.input_ids, 1, self.token_idx - 1)
+ right_padding = MAX_TOTAL_TOKENS - self.attention_mask.shape[1]
+ self.input_ids = torch.nn.functional.pad(
+ self.input_ids, (0, right_padding), value=pad_id
+ )
+ self.attention_mask = torch.nn.functional.pad(
+ self.attention_mask, (0, right_padding), value=0
+ )
+ # if self.position_ids is not None:
+ # self.position_ids = torch.index_select(self.position_ids, 1, self.token_idx - 1) + 1
+ if self.cross_attention_mask is not None:
+ self.cross_attention_mask = torch.nn.functional.pad(
+ self.cross_attention_mask, (0, 0, 0, 0, 0, right_padding), value=0
+ )
+ if self.past is not None:
+ past_key_values_list = list(self.past_key_values)
+ for layer_id in range(len(self.past)):
+ past_key_value_list = list(self.past_key_values[layer_id])
+ if layer_id not in CROSS_ATTENTION_LAYERS:
+ past_key_value_list[0] = torch.nn.functional.pad(
+ self.past_key_values[layer_id][0],
+ (0, 0, 0, right_padding),
+ value=0,
+ )
+ past_key_value_list[1] = torch.nn.functional.pad(
+ self.past_key_values[layer_id][1],
+ (0, 0, 0, right_padding),
+ value=0,
+ )
+ past_key_values_list[layer_id] = tuple(past_key_value_list)
+ self.past_key_values = tuple(past_key_values_list)
+
+ self.prefilling = False
+ self.input_length = self.input_length
+
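+    # Build a batch from pre-tokenized inputs: sequences are left-padded to the next warmed-up
+    # prefill bucket and one extra slot is reserved on the right for the first generated token.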
+ @classmethod
+ def from_tokenized(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ batch_tokenized_inputs,
+ dtype: torch.dtype,
+ device: torch.device,
+ is_warmup: bool = False,
+ ) -> "VlmCausalLMBatch":
+
+ dbg_trace("FROM_PB", f"num_reqs:{len(pb.requests)}")
+ requests = [
+ CausalLMRequest.from_pb(idx, req, tokenizer)
+ for idx, req in enumerate(pb.requests)
+ ]
+
+ max_input_length = max(r.data.truncate for r in requests)
+ max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests)
+ # TODO: Add support for sparse batches
+ top_n_tokens = [r.top_n_tokens for r in pb.requests]
+ top_n_tokens_tensor = torch.tensor(
+ top_n_tokens, device=device, dtype=torch.int64
+ )
+
+        # TODO: by tokenizing all inputs at once we lose information on actual input lengths
+ # this means that we cannot shift inputs to the left after a long input sequence
+ # was filtered out
+ new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
+ parameters = [r.parameters for r in pb.requests]
+ # append the dummy parameters for dummy request
+ parameters = pad_next_token_chooser_parameters(parameters, new_bs)
+
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ pb=parameters,
+ dtype=dtype,
+ device=device,
+ tokenizer=tokenizer,
+ quantization_enabled=hq_env.is_quantization_enabled,
+ )
+ tokenized_inputs = batch_tokenized_inputs
+ input_len = tokenized_inputs["input_ids"].shape[1]
+
+ bucket_size = max_input_length
+ left_padding = max_input_length - input_len
+ if is_warmup is False:
+ rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
+ bucket_size = rounded_seq_len - 1
+ left_padding = bucket_size - input_len
+
+ input_ids = tokenized_inputs["input_ids"]
+ attention_mask = tokenized_inputs["attention_mask"]
+ cross_attention_mask = tokenized_inputs.get("cross_attention_mask", None)
+ # Allocate space for first token
+ input_ids = torch.nn.functional.pad(
+ input_ids, (left_padding, 1), value=tokenizer.pad_token_id
+ )
+ attention_mask = torch.nn.functional.pad(
+ attention_mask, (left_padding, 1), value=0
+ )
+ if cross_attention_mask is not None:
+ cross_attention_mask = torch.nn.functional.pad(
+ cross_attention_mask, (0, 0, 0, 0, left_padding, 1), value=0
+ )
+ all_input_ids = torch.nn.functional.pad(
+ input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
+ ).T.split(1, dim=1)
+
+ # New input length after left padding
+ input_len = bucket_size
+ for r in requests:
+ r.input_length = input_len
+ r.prefix_offset = input_len - 5
+ r.read_offset = input_len
+ r.all_input_ids = all_input_ids[r.idx]
+ input_ids = input_ids.to(device)
+ attention_mask = attention_mask.to(device)
+ cross_attention_mask = (
+ cross_attention_mask.to(device)
+ if cross_attention_mask is not None
+ else None
+ )
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ htorch.core.mark_step()
+
+ return cls(
+ batch_id=pb.id,
+ requests=requests,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=None,
+ merged_kv_cache=False,
+ next_token_chooser=next_token_chooser,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ input_length=input_len,
+ cross_attention_mask=cross_attention_mask,
+ )
+
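+    # Run the multimodal processor over all requests at once; outside of warmup the batch is
+    # padded with copies of the first request up to the next warmed-up prefill batch size.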
+ @classmethod
+ def batch_tokenized_inputs(
+ cls,
+ requests: Iterable[generate_pb2.Request],
+ tokenizer,
+ processor,
+ config,
+ is_warmup,
+ ):
+ image_inputs = {}
+ texts = []
+ images = []
+ batch_tokenized_inputs = {}
+
+ for i, r in enumerate(requests):
+ # Each input is encoded into a list, where each element of this input list is either a string or a URL
+ curr_text = ""
+ curr_image = None
+ for chunk in r.input_chunks.chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ curr_text += chunk.text
+ elif chunk_type == "image":
+ image = Image.open(BytesIO(chunk.image.data))
+ # TODO unsure about BOS
+ curr_image = image
+ else:
+ raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+            if image_text_replacement(config) not in curr_text:
+                if "<image>" in curr_text:
+                    curr_text = curr_text.replace(
+                        "<image>", image_text_replacement(config)
+                    )
+                else:
+                    curr_text = image_text_replacement(config) + curr_text
+
+ texts.append(curr_text)
+ if curr_image is not None:
+ if config.model_type == "mllama":
+ images.append([curr_image])
+ else:
+ images.append(curr_image)
+
+ if is_warmup is True:
+ images += [images[0]] * (len(texts) - len(images))
+
+ missing_inputs = 0
+ dummy_images = None
+ if is_warmup is False:
+ new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
+ missing_inputs = new_bs - len(requests)
+ if missing_inputs > 0:
+ dummy_inputs = []
+ if len(texts) > 0:
+ dummy_inputs = [texts[0]] * missing_inputs
+ dummy_images = [images[0]] * missing_inputs
+ texts += dummy_inputs
+ images += dummy_images
+
+ processor_output = processor(
+ images,
+ texts,
+ truncation=True,
+ max_length=r.truncate,
+ add_special_tokens=r.add_special_tokens,
+ return_tensors="pt",
+ padding_side="left",
+ padding="longest",
+ )
+ if "input_ids" in processor_output:
+ batch_tokenized_inputs.update({"input_ids": processor_output["input_ids"]})
+ if "attention_mask" in processor_output:
+ batch_tokenized_inputs.update(
+ {"attention_mask": processor_output["attention_mask"]}
+ )
+ if "cross_attention_mask" in processor_output:
+ batch_tokenized_inputs.update(
+ {"cross_attention_mask": processor_output["cross_attention_mask"]}
+ )
+ if "pixel_values" in processor_output:
+ image_inputs.update({"pixel_values": processor_output["pixel_values"]})
+ if "pixel_attention_mask" in processor_output:
+ image_inputs.update(
+ {"pixel_attention_mask": processor_output["pixel_attention_mask"]}
+ )
+ if "aspect_ratio_ids" in processor_output:
+ image_inputs.update(
+ {"aspect_ratio_ids": processor_output["aspect_ratio_ids"]}
+ )
+ if "aspect_ratio_mask" in processor_output:
+ image_inputs.update(
+ {"aspect_ratio_mask": processor_output["aspect_ratio_mask"]}
+ )
+ if "image_sizes" in processor_output:
+ image_inputs.update({"image_sizes": processor_output["image_sizes"]})
+
+ return batch_tokenized_inputs, image_inputs
+
+ @classmethod
+ def from_pb_processor(
+ cls,
+ pb: generate_pb2.Batch,
+ tokenizer: PreTrainedTokenizerBase,
+ processor,
+ config,
+ dtype: torch.dtype,
+ device: torch.device,
+ is_warmup: bool = False,
+ ) -> "VlmCausalLMBatch":
+ batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+ pb.requests, tokenizer, processor, config, is_warmup
+ )
+ batch = cls.from_tokenized(
+ pb, tokenizer, batch_tokenized_inputs, dtype, device, is_warmup=is_warmup
+ )
+ if image_inputs is not None:
+ batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+ if "pixel_attention_mask" in image_inputs:
+ batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+ device=device
+ )
+ else:
+ batch.pixel_attention_mask = None
+ if "image_sizes" in image_inputs:
+ batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+ else:
+ batch.image_sizes = None
+ if "aspect_ratio_ids" in image_inputs:
+ batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(
+ device=device
+ )
+ else:
+ batch.aspect_ratio_ids = None
+ if "aspect_ratio_mask" in image_inputs:
+ batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
+ device=device
+ )
+ else:
+ batch.aspect_ratio_mask = None
+ else:
+ batch.pixel_values = None
+ batch.pixel_attention_mask = None
+ batch.image_sizes = None
+ batch.aspect_ratio_ids = None
+ batch.aspect_ratio_mask = None
+ batch.cross_attention_mask = None
+
+ return batch
+
+ @classmethod
+ @tracer.start_as_current_span("concatenate")
+ def concatenate(
+ cls,
+ batches: List["CausalLMBatch"],
+ pad_token_id: int = 0,
+ is_warmup: bool = False,
+ ) -> "CausalLMBatch":
+ return cls.recombine(batches, pad_token_id, is_warmup)
+
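+    # Merge multiple batches (CONCAT) or re-pad a freshly prefilled one (SHIFT) into a single
+    # decode batch whose size is rounded up to a warmed-up decode bucket.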
+ @classmethod
+ def recombine(
+ cls,
+ batches: List["VlmCausalLMBatch"],
+ pad_token_id: int,
+ is_warmup: bool = False,
+ ) -> "VlmCausalLMBatch":
+ if not all(b.past_key_values is not None for b in batches):
+ raise ValueError("KV cache not allocated! Cannot recombine before prefill!")
+ # Used for padding
+
+ total_requests = sum(len(b) for b in batches)
+ new_bs = total_requests
+ if not is_warmup:
+ new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
+
+ if len(batches) > 1:
+ scenario = "CONCAT"
+ elif batches[0].prefilling:
+ scenario = "SHIFT"
+ else:
+ return batches[0]
+
+ dbg_trace(
+ scenario,
+ f"bs:{[b.batch_size for b in batches]}->{new_bs}"
+ f" reqs:{[len(b) for b in batches]}",
+ )
+
+ if scenario == "SHIFT":
+ batch = batches[0]
+ batch.padding_process(pad_token_id)
+ return batch
+
+ total_batch_size = 0
+ max_input_length = 0
+ for i, batch in enumerate(batches):
+ total_batch_size += len(batch)
+ max_input_length = max(max_input_length, batch.input_length)
+ # Batch attributes
+ requests = []
+ input_lengths = []
+ top_n_tokens = []
+ parameters = []
+ fsm_grammar_states = []
+
+ # Batch tensors
+ input_ids = None
+ attention_mask = None
+ position_ids = None
+ past_key_values = []
+ top_n_tokens_tensor = None
+ cross_attention_mask = None
+ # Used for slicing correctly inside the tensors
+ # Equivalent to a cumsum on batch sizes
+ start_index = 0
+ for i, batch in enumerate(batches):
+ keep_indices = []
+ for req in batch.requests:
+ keep_indices.append(req.idx)
+
+ requests.extend(batch.requests)
+ parameters.extend([r.data.parameters for r in batch.requests])
+ fsm_grammar_states.extend(
+ [batch.next_token_chooser.fsm_grammar_states[i] for i in keep_indices]
+ )
+ input_lengths.extend([batch.input_length])
+ top_n_tokens.extend([batch.top_n_tokens[i] for i in keep_indices])
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+
+ # We only concatenate batches that did at least one step
+ if batch.past_key_values is None:
+ raise ValueError("only concatenate prefilled batches")
+
+            # Allocate the static-shape input_ids buffer for the merged batch
+            if input_ids is None:
+                input_ids = batch.input_ids.new_empty((new_bs, MAX_TOTAL_TOKENS))
+            # Copy each batch's rows into the buffer at the proper left/right offsets
+
+ left_offset = max_input_length - batch.input_length
+ right_padding = MAX_TOTAL_TOKENS - max_input_length
+ input_ids[start_index:end_index, left_offset:-right_padding] = (
+ batch.input_ids[keep_indices, : batch.input_length]
+ )
+
+ # Create padded tensor
+ if top_n_tokens_tensor is None:
+ top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+ new_bs,
+ )
+ top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor[
+ keep_indices
+ ]
+
+ if attention_mask is None:
+ attention_mask = batch.attention_mask.new_zeros(
+ (new_bs, MAX_TOTAL_TOKENS),
+ )
+
+ attention_mask[
+ start_index:end_index,
+ left_offset:-right_padding,
+ ] = batch.attention_mask[
+ keep_indices,
+ : batch.input_length,
+ ]
+
+ if batch.cross_attention_mask is not None:
+ cross_attention_mask_shape = list(batch.cross_attention_mask.shape)
+ cross_attention_mask_shape[1] = MAX_TOTAL_TOKENS
+ cross_attention_mask_shape[0] = new_bs
+ cross_attention_mask_shape = torch.Size(cross_attention_mask_shape)
+ if cross_attention_mask is None:
+ cross_attention_mask = batch.cross_attention_mask.new_zeros(
+ cross_attention_mask_shape,
+ )
+ cross_attention_mask[
+ start_index:end_index,
+ left_offset:-right_padding,
+ ] = batch.cross_attention_mask[
+ keep_indices,
+ : batch.input_length,
+ ]
+
+ # Create empty tensor
+ # position_ids is always of shape [batch_size, 1]
+ if position_ids is None:
+ position_ids = batch.position_ids.new_empty((new_bs, 1))
+ position_ids[start_index:end_index] = batch.position_ids[keep_indices, :]
+
+ # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+ # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
+ # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
+ # And ensure that we can update tensors in-place
+ if isinstance(batch.past_key_values, tuple):
+ batch.past_key_values = [
+ [t.view(batch.batch_size, -1, *t.shape[-2:]) for t in layer]
+ for layer in batch.past_key_values
+ ]
+ elif len(batch.past_key_values[0][0].shape) == 3:
+ for layer in batch.past_key_values:
+ for k, t in enumerate(layer):
+ layer[k] = t.view(batch.batch_size, -1, *t.shape[-2:])
+
+ start_index = end_index
+
+ first_past_kvs = batches[0].past_key_values
+ _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
+ past_key_values = []
+ for layer_id in range(len(batches[0].past_key_values)):
+ if layer_id in CROSS_ATTENTION_LAYERS:
+ padded_past_keys_shape = list(
+ batches[0].past_key_values[layer_id][0].shape
+ )
+ padded_past_keys_shape[0] = new_bs
+ padded_past_keys_shape = torch.Size(padded_past_keys_shape)
+ else:
+ padded_past_keys_shape = (
+ new_bs,
+ num_heads,
+ MAX_TOTAL_TOKENS,
+ head_dim,
+ )
+
+ padded_past_keys = first_past_kvs[layer_id][0].new_zeros(
+ padded_past_keys_shape
+ )
+ padded_past_values = first_past_kvs[layer_id][1].new_zeros(
+ padded_past_keys_shape
+ )
+ start_index = 0
+ for batch in batches:
+ keep_indices = []
+ for req in batch.requests:
+ keep_indices.append(req.idx)
+
+ left_offset = max_input_length - batch.input_length
+ right_padding = MAX_TOTAL_TOKENS - max_input_length
+ past_keys = batch.past_key_values[layer_id][0]
+ past_values = batch.past_key_values[layer_id][1]
+ # Clear reference to the original tensor
+ batch.past_key_values[layer_id] = None
+
+ # Slicing end index for this batch
+ end_index = start_index + len(batch)
+ # We slice the keys to remove the padding from previous batches
+ if layer_id in CROSS_ATTENTION_LAYERS:
+ padded_past_keys[start_index:end_index, :, :, :] = past_keys[
+ keep_indices, :, :, :
+ ]
+ padded_past_values[start_index:end_index, :, :, :] = past_values[
+ keep_indices, :, :, :
+ ]
+
+ else:
+ padded_past_keys[
+ start_index:end_index, :, left_offset:-right_padding, :
+ ] = past_keys[keep_indices, :, : batch.input_length, :]
+ padded_past_values[
+ start_index:end_index, :, left_offset:-right_padding, :
+ ] = past_values[keep_indices, :, : batch.input_length, :]
+
+ start_index = end_index
+
+ past_key_values.append(tuple([padded_past_keys, padded_past_values]))
+ past_key_values = tuple(past_key_values)
+
+ batch_id = batches[0].batch_id
+ top_n_tokens.extend([-1] * (new_bs - total_batch_size))
+ fsm_grammar_states.extend([-1] * (new_bs - total_batch_size))
+
+ for idx, req in enumerate(requests):
+ req.idx = idx
+
+ parameters = pad_next_token_chooser_parameters(parameters, new_bs)
+ next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+ parameters,
+ batches[0].next_token_chooser.dtype,
+ batches[0].next_token_chooser.device,
+ batches[0].next_token_chooser.tokenizer,
+ fsm_grammar_states,
+ quantization_enabled=hq_env.is_quantization_enabled,
+ )
+ input_length = max_input_length
+
+ htorch.core.mark_step()
+
+ return cls(
+ batch_id=batch_id,
+ requests=requests,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ merged_kv_cache=False,
+ next_token_chooser=next_token_chooser,
+ top_n_tokens=top_n_tokens,
+ top_n_tokens_tensor=top_n_tokens_tensor,
+ input_length=input_length,
+ pixel_values=None,
+ pixel_attention_mask=None,
+ image_sizes=None,
+ aspect_ratio_ids=None,
+ aspect_ratio_mask=None,
+ cross_attention_mask=cross_attention_mask,
+ prefilling=False,
+ )
+
+
+class VlmCausalLM(Model):
+ def __init__(
+ self,
+ model_class,
+ model_id: str,
+ *,
+ processor_class=AutoProcessor,
+ processor_kwargs=None,
+ batch_class=VlmCausalLMBatch,
+ revision,
+ quantize: Optional[str] = None,
+ dtype,
+ trust_remote_code: bool,
+ **kwargs,
+ ):
+ adapt_transformers_to_gaudi()
+ if processor_kwargs is None:
+ processor_kwargs = {}
+ self.processor = processor_class.from_pretrained(
+ model_id,
+ revision=revision,
+ trust_remote_code=trust_remote_code,
+ **processor_kwargs,
+ )
+ self.batch_class = batch_class
+ self.prev_bs = 0
+ self.quantize = quantize
+
+ # Create tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ )
+ make_tokenizer_optional(tokenizer)
+
+ # Create model
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ rank = int(os.getenv("RANK", "0"))
+ dtype = torch.bfloat16 if dtype is None else dtype
+ device = torch.device("hpu")
+
+ if hq_env.is_quantization_enabled:
+ htorch.core.hpu_set_env()
+
+ if world_size > 1:
+ os.environ.setdefault(
+ "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
+ )
+ model = self.get_deepspeed_model(model_class, model_id, dtype, revision)
+ model = hq_env.prepare_model_for_quantization(model)
+ else:
+ get_repo_root(model_id)
+
+ # Check support for rope scaling
+ model_kwargs = {}
+ config = AutoConfig.from_pretrained(model_id)
+ if hasattr(config, "rope_scaling"):
+ model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
+ model = model_class.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ **model_kwargs,
+ )
+ model = hq_env.prepare_model_for_quantization(model)
+ model = model.eval().to(device)
+
+ self.enable_hpu_graph = (
+ os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
+ )
+ self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
+ model = remove_kv_cache_from_output(model)
+ if self.enable_hpu_graph:
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+ model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+ else:
+ if LAZY_MODE == 0:
+                # "keep_input_mutations" is reported to be safe for inference
+ dbg_trace("TORCH COMPILE", "Torch compiling of model")
+ model.model = torch.compile(
+ model.model,
+ backend="hpu_backend",
+ options={"keep_input_mutations": True},
+ )
+
+ model = hq_env.setup_quantization(model)
+
+ if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
+ raise ValueError(f"Model type {model.config.model_type} is not supported!")
+
+ if tokenizer.pad_token_id is None:
+ if model.config.pad_token_id is not None:
+ tokenizer.pad_token_id = model.config.pad_token_id
+ elif model.config.eos_token_id is not None:
+ if isinstance(model.config.eos_token_id, int):
+ tokenizer.pad_token_id = model.config.eos_token_id
+ elif isinstance(model.config.eos_token_id, list):
+ tokenizer.pad_token_id = model.config.eos_token_id[0]
+ else:
+ raise ValueError(
+ f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id"
+ )
+ elif tokenizer.eos_token_id is not None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ else:
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+ self.kwargs = {
+ "use_cache": True,
+ "return_dict": True,
+ }
+
+ if model.config.model_type in ["llava_next"]:
+ self.kwargs["attn_softmax_bf16"] = True
+ self.kwargs["trim_logits"] = True
+
+ if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
+ self.kwargs["use_flash_attention"] = True
+ if os.getenv("FLASH_ATTENTION_RECOMPUTE", "true").lower() == "true":
+ self.kwargs["flash_attention_recompute"] = True
+
+ self.speculate = get_speculate()
+ if model.config.model_type == "mllama":
+ global CROSS_ATTENTION_LAYERS, BASE_IMAGE_TOKENS
+ CROSS_ATTENTION_LAYERS = model.config.text_config.cross_attention_layers
+ BASE_IMAGE_TOKENS = 0
+
+ super(VlmCausalLM, self).__init__(
+ model_id=model_id,
+ model=model,
+ tokenizer=tokenizer,
+ requires_padding=True,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ )
+
+ # Create profiler
+ ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
+ record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
+ output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
+ self.profiling_warmup_steps = (
+ int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+ )
+ self.profiling_steps = (
+ int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+ )
+ self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
+ if self.profiling_steps > 0:
+ self.hb_profiler = HabanaProfile(
+ wait=self.profiling_wait_steps,
+ warmup=self.profiling_warmup_steps,
+ active=self.profiling_steps,
+ output_dir=output_dir,
+ record_shapes=record_shapes,
+ )
+ self.hb_profiler.start()
+ else:
+ self.hb_profiler = None
+ self.step = 0
+
+ @property
+ def batch_type(self) -> Type[VlmCausalLMBatch]:
+ return self.batch_class
+
+ def max_past(self) -> Optional[int]:
+ return getattr(self.model.text_model, "max_past", None)
+
+ def get_deepspeed_model(
+ self,
+ model_class,
+ model_id: str,
+ dtype: torch.dtype,
+ revision: Optional[str] = None,
+ ) -> torch.nn.Module:
+ import deepspeed
+ from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
+
+ world_size, rank, local_rank = initialize_distributed_hpu()
+ model_kwargs = {"revision": revision}
+
+ # Initialize process(es) for DeepSpeed
+ deepspeed.init_distributed(dist_backend="hccl")
+ logger.info(
+ "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(
+ world_size, rank, local_rank
+ )
+ )
+ config = AutoConfig.from_pretrained(model_id, **model_kwargs)
+ load_to_meta = model_on_meta(config)
+
+ # Check support for rope scaling
+ if hasattr(config, "rope_scaling"):
+ config.rope_scaling = self.get_rope_scaling()
+ model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
+ if load_to_meta:
+ # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
+ with deepspeed.OnDevice(dtype=dtype, device="meta"):
+ model = model_class.from_config(config, torch_dtype=dtype)
+ else:
+ get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
+ # TODO: revisit placement on CPU when auto-injection is possible
+ with deepspeed.OnDevice(dtype=dtype, device="cpu"):
+ model = model_class.from_pretrained(
+ model_id, torch_dtype=dtype, **model_kwargs
+ )
+ model = model.eval()
+
+ # Initialize the model
+ ds_inference_kwargs = {"dtype": dtype}
+ ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
+ ds_inference_kwargs["enable_cuda_graph"] = False
+ ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(
+ model.language_model.config
+ )
+
+ if load_to_meta:
+ # model loaded to meta is managed differently
+ checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
+ write_checkpoints_json(model_id, local_rank, checkpoints_json)
+ ds_inference_kwargs["checkpoint"] = checkpoints_json.name
+ model = deepspeed.init_inference(model, **ds_inference_kwargs)
+
+ return model.module
+
+ def get_rope_scaling(self) -> Optional[Dict]:
+ rope_scaling = os.getenv("ROPE_SCALING", None)
+ if rope_scaling is None:
+ return None
+
+ rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
+ return {"type": rope_scaling, "factor": float(rope_factor)}
+
+ def decode(self, generated_ids: List[int]) -> str:
+ return self.tokenizer.decode(
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ def decode_token(
+ self,
+ all_input_ids: List[int],
+ prefix_offset: int = 0,
+ read_offset: int = 0,
+ ) -> Tuple[str, int, int]:
+ if is_tokenizer_transparent(self.tokenizer):
+ new_text = self.tokenizer.decode(
+ all_input_ids[read_offset:], skip_special_tokens=False
+ )
+ return new_text, read_offset, len(all_input_ids)
+ else:
+ return super().decode_token(all_input_ids, prefix_offset, read_offset)
+
+ def forward(
+ self,
+ batch: VlmCausalLMBatch,
+ bypass_hpu_graph: Optional[bool] = None,
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+ # Model Forward
+ kwargs = {
+ "input_ids": batch.input_ids,
+ "attention_mask": batch.attention_mask,
+ "past_key_values": batch.past_key_values,
+ "token_idx": batch.token_idx,
+ "pixel_values": batch.pixel_values,
+ }
+
+ if self.model.config.model_type == "mllama":
+ kwargs["aspect_ratio_ids"] = batch.aspect_ratio_ids
+ kwargs["aspect_ratio_mask"] = batch.aspect_ratio_mask
+ kwargs["cross_attention_mask"] = batch.cross_attention_mask
+ else:
+ kwargs["image_sizes"] = batch.image_sizes
+
+ hpu_kwargs = {}
+        # Optimum Habana only supports the "lazy_mode" kwarg for llama-type models
+ if self.model.config.model_type == "llama":
+ hpu_kwargs["lazy_mode"] = LAZY_MODE == 1
+
+ if self.has_position_ids:
+ kwargs["position_ids"] = batch.position_ids
+ if bypass_hpu_graph is not None:
+ hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
+
+ kwargs.update(self.kwargs)
+ model_inputs = self.model.prepare_inputs_for_generation(**kwargs)
+
+ if batch.past_key_values is not None:
+ return self.model.forward(**model_inputs, **hpu_kwargs)
+ else:
+ outputs = self.model.forward(**model_inputs, **hpu_kwargs)
+ return outputs.logits, outputs.past_key_values
+
+ @tracer.start_as_current_span("generate_token")
+ def generate_token(
+ self, batches: list[VlmCausalLMBatch], is_warmup: bool = False
+ ) -> Tuple[List[Generation], Optional[VlmCausalLMBatch], Tuple[int, int]]:
+
+ start = time.time_ns()
+ # Results
+ generations: List[Generation] = []
+ prev_batches = []
+ requests_to_generate = []
+ # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
+ # Stage 1. Collect next token ids of any previously started generations
+ for batch_id, batch in enumerate(batches):
+ if batch.logits is not None:
+ logits = batch.logits
+ past = batch.past
+ prefill = batch.past_key_values is None
+ if prefill:
+ # no right padding for prefill
+ token_idx_scalar = batch.attention_mask.shape[-1] - 1
+ token_idx = torch.tensor(token_idx_scalar).to(self.device)
+ else:
+ token_idx_scalar = (
+ batch.attention_mask.shape[-1] - batch.right_padding
+ )
+ token_idx = torch.tensor(token_idx_scalar).to(self.device)
+
+ # Select next token
+ input_length = batch.input_length
+ if logits.shape[-2] > 1:
+ next_token_ids, next_token_logprobs, logprobs, _, _ = (
+ batch.next_token_chooser(
+ batch.input_ids,
+ logits[:, input_length - 1 : input_length, :].squeeze(-2),
+ self.speculate,
+ )
+ )
+ else:
+ next_token_ids, next_token_logprobs, logprobs, _, _ = (
+ batch.next_token_chooser(
+ batch.input_ids, logits.squeeze(-2), self.speculate
+ )
+ )
+ # Speculation is not active for causal
+ accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
+ batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+ batch.top_n_tokens,
+ batch.top_n_tokens_tensor,
+ logprobs,
+ accepted_ids,
+ )
+
+ prev_batches.append(
+ {
+ "next_token_ids": next_token_ids,
+ "next_token_logprobs": next_token_logprobs,
+ }
+ )
+
+ for req_idx, req in enumerate(batch.requests):
+ requests_to_generate.append(
+ {
+ "req": req,
+ "prev_req_idx": req.idx,
+ "batch_id": batch_id,
+ "seed": batch.next_token_chooser.seeds[req_idx],
+ "do_sample": batch.next_token_chooser.do_sample[req_idx],
+ "top_n_tokens": batch.top_n_tokens[req_idx],
+ "top_token_ids": batch_top_token_ids[req_idx],
+ "top_token_logprobs": batch_top_token_logprobs[req_idx],
+ "grammar_state": batch.next_token_chooser.fsm_grammar_states[
+ req.idx
+ ],
+ }
+ )
+
+ htorch.core.mark_step()
+
+ # Add new token into input_ids
+ batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
+
+ # Update attention_mask as we added a new token to input_ids
+ batch.attention_mask.index_fill_(1, token_idx, 1)
+
+ # add cross-attn mask for new token
+ if batch.cross_attention_mask is not None:
+ cross_attention_mask_prev = batch.cross_attention_mask
+ if token_idx is not None:
+ mask = cross_attention_mask_prev[
+ :, token_idx - 2 : token_idx - 1, ...
+ ]
+ cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
+ batch.cross_attention_mask = cross_attention_mask_prev
+
+ # Adjust lengths
+ batch.input_length += 1
+ # Update position_ids
+ if prefill:
+ batch.position_ids = (
+ torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
+ )
+ else:
+ batch.position_ids += 1
+ # Update past key values
+ if prefill:
+ batch.past_key_values = past
+
+ htorch.core.mark_step()
+
+ # Stage 2. Prepare new batch for speculative scheduling
+ if len(batches) > 1:
+ batch = self.batch_type.concatenate(
+ batches, self.tokenizer.pad_token_id, is_warmup
+ )
+ else:
+ batch = batches[0]
+
+ prefill = batch.past_key_values is None
+
+ # Check if we need to do any bookkeeping first
+ if not prefill:
+ batch = self.batch_type.recombine(
+ [batch], self.tokenizer.pad_token_id, is_warmup
+ )
+
+ scenario = "PREFILL" if prefill else "GENERATE"
+ if (
+ self.enable_hpu_graph
+ and self.limit_hpu_graph
+ and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
+ != self.prev_bs
+ ):
+ self.model.clear_cache()
+ self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
+ dbg_trace(
+ scenario,
+ f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
+ )
+ # assert batch.right_padding > 0, 'No more room for next token!'
+
+ # Execute batch
+ if prefill:
+ # no right padding for prefill
+ # token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+ batch.logits, batch.past = self.forward(
+ batch,
+ bypass_hpu_graph=(
+ prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+ ),
+ )
+
+ elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
+ # Don't schedule next forward if max_new_tokens for all requests equals 1
+ # - we've already generated the first and only needed token in the prefill phase
+ pass
+ else:
+ # token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+ batch.logits = self.forward(
+ batch,
+ bypass_hpu_graph=(
+ prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
+ ),
+ )
+
+ if batch.pixel_values is not None:
+ batch.pixel_values = None
+ if batch.aspect_ratio_ids is not None:
+ batch.aspect_ratio_ids = None
+ if batch.aspect_ratio_mask is not None:
+ batch.aspect_ratio_mask = None
+
+ htorch.core.mark_step()
+
+ start_decode = time.time_ns()
+
+ # Stage 3. Finish and return previous generations
+ stopped = len(requests_to_generate) > 0
+ for prev_batch in prev_batches:
+ prev_batch["next_token_logprobs"] = prev_batch[
+ "next_token_logprobs"
+ ].tolist()
+ prev_batch["next_token_ids_cpu"] = prev_batch["next_token_ids"].cpu()
+ htorch.core.mark_step()
+
+ for req_data in requests_to_generate:
+ req = req_data["req"]
+ i = req_data["prev_req_idx"]
+ prev_batch_id = req_data["batch_id"]
+ assert len(prev_batches) > prev_batch_id
+ next_token_ids_cpu = prev_batches[prev_batch_id]["next_token_ids_cpu"]
+ next_token_logprobs = prev_batches[prev_batch_id]["next_token_logprobs"]
+
+ request = req.data
+ input_length = req.input_length
+ prefix_offset = req.prefix_offset
+ read_offset = req.read_offset
+ do_sample = req_data["do_sample"]
+ seed = req_data["seed"]
+ stopping_criteria = req.stopping_criteria
+ all_input_ids = req.all_input_ids
+ next_token_id = next_token_ids_cpu[i]
+ next_token_logprob = next_token_logprobs[i]
+ top_n_tokens = req_data["top_n_tokens"]
+ top_token_ids = req_data["top_token_ids"]
+ top_token_logprobs = req_data["top_token_logprobs"]
+ grammar_state = req_data["grammar_state"]
+
+ # Append next token to all tokens
+ all_input_ids[input_length] = next_token_id
+ new_input_length = input_length + 1
+
+ # Generated token
+ if (
+ is_tokenizer_transparent(self.tokenizer)
+ and len(stopping_criteria.stop_sequence_criterias) == 0
+ ):
+ next_token_text = ""
+ else:
+ next_token_text, prefix_offset, read_offset = self.decode_token(
+ all_input_ids[0:new_input_length, 0], prefix_offset, read_offset
+ )
+
+ # Evaluate stopping criteria
+ stop, reason = stopping_criteria(
+ next_token_id,
+ next_token_text,
+ )
+
+ if not stop:
+ stopped = False
+
+ # Shard generations
+ # All generations will be appended in the rust sharded client
+ if i % self.world_size == self.rank:
+ if stop:
+ # Decode generated tokens
+ if is_tokenizer_transparent(self.tokenizer):
+ output_text = None
+ else:
+ output_text = self.decode(
+ all_input_ids[
+ new_input_length
+ - stopping_criteria.current_tokens : new_input_length,
+ 0,
+ ]
+ )
+ generated_text = GeneratedText(
+ output_text,
+ stopping_criteria.current_tokens,
+ reason,
+ seed if do_sample else None,
+ )
+ else:
+ generated_text = None
+
+ # Prefill
+ if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+ # Remove generated token to only have prefill and add nan for first prompt token
+ prefill_logprobs = [float("nan")] + next_token_logprobs
+ prefill_token_ids = all_input_ids[0 : new_input_length - 1]
+ prefill_texts = self.tokenizer.batch_decode(
+ prefill_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ prefill_tokens = Tokens(
+ prefill_token_ids,
+ prefill_logprobs,
+ prefill_texts,
+ is_special=[],
+ )
+ else:
+ prefill_tokens = None
+
+ if top_n_tokens > 0:
+ all_top_tokens = []
+ for top_token_ids, top_token_logprobs in zip(
+ top_token_ids, top_token_logprobs
+ ):
+ toptoken_texts = self.tokenizer.batch_decode(
+ top_token_ids,
+ clean_up_tokenization_spaces=False,
+ skip_special_tokens=False,
+ )
+ special_toptokens = [
+ token_id in self.all_special_ids
+ for token_id in top_token_ids
+ ]
+ top_tokens = Tokens(
+ top_token_ids,
+ top_token_logprobs,
+ toptoken_texts,
+ special_toptokens,
+ )
+ all_top_tokens.append(top_tokens)
+ top_tokens = all_top_tokens
+ else:
+ top_tokens = None
+
+ generation = Generation(
+ request.id,
+ prefill_tokens,
+ Tokens(
+ [next_token_id],
+ [next_token_logprob],
+ [next_token_text],
+ [next_token_id in self.all_special_ids],
+ ),
+ generated_text,
+ top_tokens,
+ )
+
+ generations.append(generation)
+
+ batch.next_token_chooser = (
+ batch.next_token_chooser.advance_grammar_single_with_past_state(
+ req.idx, next_token_id, grammar_state
+ )
+ )
+
+ req.all_input_ids = all_input_ids
+ req.input_length = new_input_length
+ req.prefix_offset = prefix_offset
+ req.read_offset = read_offset
+
+ htorch.core.mark_step()
+ self.step = self.step + 1
+ if self.hb_profiler is not None:
+ if (
+ self.step
+ > self.profiling_wait_steps
+ + self.profiling_warmup_steps
+ + self.profiling_steps
+ ):
+ self.hb_profiler.stop()
+ else:
+ self.hb_profiler.step()
+
+ forward_ns = start_decode - start
+ decode_ns = time.time_ns() - start_decode
+ return generations, batch if not stopped else None, (forward_ns, decode_ns)
+
+ def batch_from_pb(self, batch, is_warmup):
+ return self.batch_type.from_pb_processor(
+ batch,
+ self.tokenizer,
+ self.processor,
+ self.model.config,
+ self.dtype,
+ self.device,
+ is_warmup,
+ )
+
+ def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
+ batch = copy.deepcopy(request.batch)
+ for req in batch.requests:
+ req.truncate = seq_len
+
+ for i in range(len(batch.requests) - batch_size):
+ batch.requests.pop()
+
+ return self.batch_from_pb(batch, is_warmup)
+
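+    # Warm up HPU graphs by prefilling and decoding every combination of the prefill
+    # batch-size buckets (doubling up to the max prefill batch size) and sequence-length
+    # buckets, then extend the decode buckets up to MAX_BATCH_SIZE.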
+ def warmup(
+ self, request: generate_pb2.WarmupRequest
+ ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+ global MAX_TOTAL_TOKENS
+ MAX_TOTAL_TOKENS = request.max_total_tokens
+ batch = self.batch_from_pb(request.batch, is_warmup=True)
+ max_input_tokens = request.max_input_tokens
+ max_prefill_batch_size = batch.input_ids.shape[0]
+ max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
+ if max_batch_size_str is not None:
+ MAX_BATCH_SIZE = int(max_batch_size_str)
+ else:
+ raise ValueError("MAX_BATCH_SIZE is not set")
+
+ try:
+ # max prefill batch size warmup
+ _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
+ except Exception:
+ raise RuntimeError(
+ f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+ f"You need to decrease `--max-batch-prefill-tokens`"
+ )
+
+ global BASE_IMAGE_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
+ PREFILL_WARMUP_BATCH_SIZE_LIST = []
+ batch_size = 1
+ while batch_size <= max_prefill_batch_size:
+ PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+ batch_size = batch_size * 2
+ if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size:
+ PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
+
+ if self.model.config.model_type == "mllama":
+ seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
+ else:
+ seq_len = BASE_IMAGE_TOKENS
+
+ PREFILL_WARMUP_SEQLEN_LIST = []
+ i = 0
+ while seq_len <= max_input_tokens:
+ PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
+ seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
+ i += 1
+ if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_tokens:
+ PREFILL_WARMUP_SEQLEN_LIST.append(max_input_tokens)
+
+ # Prefill and decode warmup
+ DECODE_WARMUP_BATCH_SIZE_LIST = []
+ prefill_batch = None
+ decode_batch = None
+ try:
+ for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
+ for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
+ batch = self.generate_warmup_batch(
+ request, seq_len, batch_size, is_warmup=True
+ )
+ _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)
+ assert prefill_batch is not None
+ _, decode_batch, _ = self.generate_token(
+ [prefill_batch], is_warmup=True
+ )
+
+ DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+
+ except Exception:
+ raise RuntimeError(
+                "Not enough memory to run the prefill and decode warmup.\n"
+                f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+                f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
+                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+                "You need to decrease `--max-batch-prefill-tokens`"
+ )
+
+ mem_stats = get_hpu_memory_stats(self.device)
+ logger.info(
+            "\nPrefill and decode warmup completed successfully.\n"
+            f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
+            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats}"
+ )
+
+ max_decode_batch_size = MAX_BATCH_SIZE
+ batch_size = max_prefill_batch_size * 2
+ # Decode warmup with bigger batch_size
+ try:
+ if (
+ DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size
+ and batch_size <= max_decode_batch_size
+ ):
+ batches = []
+ while batch_size <= max_decode_batch_size:
+ for i in range(int(batch_size / max_prefill_batch_size)):
+ batch = self.generate_warmup_batch(
+ request,
+ PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
+ max_prefill_batch_size,
+ is_warmup=True,
+ )
+ _, prefill_batch, _ = self.generate_token(
+ [batch], is_warmup=True
+ )
+ batches.append(prefill_batch)
+
+ _, decode_batch, _ = self.generate_token(batches, is_warmup=True)
+ DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+ batch_size = batch_size * 2
+ batches.clear()
+
+ if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
+ max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2
+ batch_size = max_decode_batch_size
+ for i in range(int(max_decode_batch_size / 2)):
+ batch = self.generate_warmup_batch(
+ request,
+ PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
+ 2,
+ is_warmup=True,
+ )
+ _, prefill_batch, _ = self.generate_token(
+ [batch], is_warmup=True
+ )
+ batches.append(prefill_batch)
+ _, decode_batch, _ = self.generate_token(batches, is_warmup=True)
+ DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
+
+ except Exception:
+ raise RuntimeError(
+                f"Not enough memory to run the decode warmup with batch_size {batch_size}.\n"
+                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+                f"max_decode_batch_size is {max_decode_batch_size}.\n"
+                f"You need to decrease the `MAX_BATCH_SIZE` environment variable or `--max_batch_size`"
+ )
+
+ mem_stats = get_hpu_memory_stats(self.device)
+ logger.info(
+            "\nDecode warmup completed successfully.\n"
+            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+ f"Memory stats: {mem_stats}"
+ )
+
+ max_supported_total_tokens = MAX_BATCH_SIZE * MAX_TOTAL_TOKENS
+ max_input_tokens = max_input_tokens
+ max_total_tokens = MAX_TOTAL_TOKENS
+
+ return max_supported_total_tokens, max_input_tokens, max_total_tokens
diff --git a/backends/gaudi/server/text_generation_server/pb/.gitignore b/backends/gaudi/server/text_generation_server/pb/.gitignore
new file mode 100644
index 000000000..5a68d6313
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/pb/.gitignore
@@ -0,0 +1,3 @@
+*.py
+*.pyi
+*.py-e
diff --git a/backends/gaudi/server/text_generation_server/server.py b/backends/gaudi/server/text_generation_server/server.py
new file mode 100644
index 000000000..5a7d21175
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/server.py
@@ -0,0 +1,331 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import asyncio
+import os
+import torch
+import time
+import signal
+
+from grpc import aio
+from loguru import logger
+
+from grpc_reflection.v1alpha import reflection
+from pathlib import Path
+from typing import List, Optional
+
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model_with_lora_adapters
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.models.globals import set_model_id, ATTENTION
+from text_generation_server.models.globals import set_adapter_to_index
+from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.tokens import make_tokenizer_optional
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
+
+try:
+ from text_generation_server.models.pali_gemma import PaliGemmaBatch
+ from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
+ from text_generation_server.models.vlm_causal_lm import (
+ VlmCausalLMBatch,
+ )
+ from text_generation_server.models.flash_vlm_causal_lm import (
+ FlashVlmCausalLMBatch,
+ )
+
+ VLM_BATCH_TYPES = {
+ PaliGemmaBatch,
+ VlmCausalLMBatch,
+ FlashVlmCausalLMBatch,
+ FlashMllamaCausalLMBatch,
+ }
+except (ImportError, NotImplementedError):
+ # These imports can fail on CPU/Non flash.
+ VLM_BATCH_TYPES = set()
+from text_generation_server.utils.version import (
+ is_driver_compatible,
+ MIN_TGI_GAUDI_SYNAPSE_VERSION,
+)
+
+
+class SignalHandler:
+ KEEP_PROCESSING = True
+
+ def __init__(self):
+ signal.signal(signal.SIGINT, self.exit_gracefully)
+ signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+ def exit_gracefully(self, signum, frame):
+ print(f"Exiting gracefully: Signal {signum}")
+ self.KEEP_PROCESSING = False
+
+
+class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
+ def __init__(
+ self,
+ model: Model,
+ cache: Cache,
+ server_urls: List[str],
+ ):
+ self.cache = cache
+ self.model = model
+ # Quantize is resolved during model loading
+ self.quantize = model.quantize
+ self.server_urls = server_urls
+ # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        # TODO: Enabling inference_mode messes up the autograd op dispatch and results in an
+        # unoptimized aten::matmul op. Will investigate further.
+ # if model.device.type == "hpu":
+ # Force inference mode for the lifetime of TextGenerationService
+ # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+
+ async def Info(self, request, context):
+ return self.model.info
+
+ async def Health(self, request, context):
+ if self.model.device.type == "hpu":
+ torch.zeros((2, 2)).to("hpu")
+ return generate_pb2.HealthResponse()
+
+ async def ServiceDiscovery(self, request, context):
+ return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
+
+ async def ClearCache(self, request, context):
+ if request.HasField("id"):
+ self.cache.delete(request.id)
+ else:
+ self.cache.clear()
+ return generate_pb2.ClearCacheResponse()
+
+ async def FilterBatch(self, request, context):
+ batch = self.cache.pop(request.batch_id)
+ if batch is None:
+ raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+ filtered_batch = batch.filter(request.request_ids)
+ self.cache.set(filtered_batch)
+
+ return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+
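+    # For paged attention the warmup batch is built here and passed to model.warmup();
+    # otherwise the raw WarmupRequest is forwarded so the model drives its own static-shape warmup.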
+ async def Warmup(self, request, context):
+ if ATTENTION == "paged":
+ set_max_prefill_tokens(request.max_prefill_tokens)
+ if (
+ self.model.batch_type in VLM_BATCH_TYPES
+            ):  # Hack, I would rather use kwargs in the `from_pb` call
+ batch = self.model.batch_type.from_pb_processor(
+ request.batch,
+ self.model.tokenizer,
+ self.model.processor,
+ self.model.model.config,
+ self.model.dtype,
+ self.model.device,
+ )
+ else:
+ batch = self.model.batch_type.from_pb(
+ request.batch,
+ self.model.tokenizer,
+ self.model.dtype,
+ self.model.device,
+ )
+
+            # Use None when the optional proto fields are not set, for clearer semantics.
+ max_input_tokens = (
+ request.max_input_tokens
+ if request.HasField("max_input_tokens")
+ else None
+ )
+ max_total_tokens = (
+ request.max_total_tokens
+ if request.HasField("max_total_tokens")
+ else None
+ )
+ max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+ self.model.warmup(batch, max_input_tokens, max_total_tokens)
+ )
+ else:
+ max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+ self.model.warmup(request)
+ )
+
+ # W/A for the skip tokenizer path
+ # We need to call make_tokenizer_optional after the warmup,
+ # because router is not aware of that feature
+ make_tokenizer_optional(self.model.tokenizer)
+
+ return generate_pb2.WarmupResponse(
+ max_supported_total_tokens=max_supported_total_tokens,
+ max_input_tokens=max_input_tokens,
+ max_total_tokens=max_total_tokens,
+ )
+
+ async def Prefill(self, request, context):
+ start = time.time_ns()
+ if (
+ self.model.batch_type in VLM_BATCH_TYPES
+        ):  # Hack, I would rather use kwargs in the `from_pb` call
+ batch = self.model.batch_type.from_pb_processor(
+ request.batch,
+ self.model.tokenizer,
+ self.model.processor,
+ self.model.model.config,
+ self.model.dtype,
+ self.model.device,
+ )
+ else:
+ batch = self.model.batch_type.from_pb(
+ request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+ )
+
+ generations, next_batch, timings = self.model.generate_token([batch])
+ self.cache.set(next_batch)
+
+ return generate_pb2.PrefillResponse(
+ generations=[generation.to_pb() for generation in generations],
+ batch=next_batch.to_pb() if next_batch else None,
+ forward_ns=timings[0],
+ decode_ns=timings[1],
+ total_ns=time.time_ns() - start,
+ )
+
+ async def Decode(self, request, context):
+ start = time.time_ns()
+ if len(request.batches) == 0:
+ raise ValueError("Must provide at least one batch")
+
+ batches = []
+ for batch_pb in request.batches:
+ batch = self.cache.pop(batch_pb.id)
+ if batch is None:
+ raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+ batches.append(batch)
+
+ if len(batches) == 0:
+ raise ValueError("All batches are empty")
+
+ generations, next_batch, timings = self.model.generate_token(batches)
+ self.cache.set(next_batch)
+
+ return generate_pb2.DecodeResponse(
+ generations=[generation.to_pb() for generation in generations],
+ batch=next_batch.to_pb() if next_batch else None,
+ concat_ns=None,
+ forward_ns=timings[0],
+ decode_ns=timings[1],
+ total_ns=time.time_ns() - start,
+ )
+
+
+def serve(
+ model_id: str,
+ lora_adapters: Optional[List[AdapterInfo]],
+ revision: Optional[str],
+ sharded: bool,
+ quantize: Optional[str],
+ speculate: Optional[int],
+ dtype: Optional[str],
+ trust_remote_code: bool,
+ uds_path: Path,
+ max_input_tokens: int,
+):
+ async def serve_inner(
+ model_id: str,
+ lora_adapters: Optional[List[AdapterInfo]],
+ revision: Optional[str],
+ sharded: bool = False,
+ quantize: Optional[str] = None,
+ speculate: Optional[int] = None,
+ dtype: Optional[str] = None,
+ trust_remote_code: bool = False,
+ ):
+ if not is_driver_compatible():
+ logger.warning(
+ f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures"
+ )
+
+ unix_socket_template = "unix://{}-{}"
+ adapter_to_index = {}
+        logger.info("Server:server_inner: sharded = {}".format(sharded))
+
+ if sharded:
+ rank = int(os.environ["RANK"])
+            logger.info("Server:server_inner: rank = {}".format(rank))
+ server_urls = [
+ unix_socket_template.format(uds_path, rank)
+ for rank in range(int(os.environ["WORLD_SIZE"]))
+ ]
+ local_url = server_urls[int(os.environ["RANK"])]
+ else:
+ local_url = unix_socket_template.format(uds_path, 0)
+ server_urls = [local_url]
+
+ logger.info(
+ "Server:server_inner: data type = {}, local_url = {}".format(
+ dtype, local_url
+ )
+ )
+        if dtype == "bfloat16" or dtype is None:
+ data_type = torch.bfloat16
+ else:
+ data_type = torch.float
+ if revision == "None":
+ revision = None
+ try:
+ model = get_model_with_lora_adapters(
+ model_id,
+ lora_adapters,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ data_type,
+ trust_remote_code,
+ max_input_tokens,
+ adapter_to_index,
+ )
+
+ except Exception:
+ logger.exception("Error when initializing model")
+ raise
+
+ set_adapter_to_index(adapter_to_index)
+ server = aio.server(
+ interceptors=[
+ ExceptionInterceptor(),
+ UDSOpenTelemetryAioServerInterceptor(),
+ ],
+ options=[
+ # Set the maximum possible message length: i32::MAX
+ ("grpc.max_receive_message_length", (1 << 31) - 1)
+ ],
+ )
+ generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
+ TextGenerationService(model, Cache(), server_urls), server
+ )
+ SERVICE_NAMES = (
+ generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
+ reflection.SERVICE_NAME,
+ )
+ reflection.enable_server_reflection(SERVICE_NAMES, server)
+ server.add_insecure_port(local_url)
+
+ await server.start()
+
+ logger.info("Server started at {}".format(local_url))
+ signal_handler = SignalHandler()
+ while signal_handler.KEEP_PROCESSING:
+ await asyncio.sleep(0.5)
+
+ set_model_id(model_id)
+ asyncio.run(
+ serve_inner(
+ model_id,
+ lora_adapters,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ trust_remote_code,
+ )
+ )
diff --git a/backends/gaudi/server/text_generation_server/tgi_service.py b/backends/gaudi/server/text_generation_server/tgi_service.py
new file mode 100644
index 000000000..18e88a7eb
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/tgi_service.py
@@ -0,0 +1,49 @@
+import os
+from pathlib import Path
+from loguru import logger
+from text_generation_server import server
+import argparse
+from text_generation_server.utils.adapter import parse_lora_adapters
+
+
+def main(args):
+ logger.info("TGIService: starting tgi service .... ")
+ logger.info(
+ "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
+ args.model_id,
+ args.revision,
+ args.sharded,
+ args.speculate,
+ args.dtype,
+ args.trust_remote_code,
+ args.uds_path,
+ )
+ )
+ lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
+ server.serve(
+ model_id=args.model_id,
+ lora_adapters=lora_adapters,
+ revision=args.revision,
+ sharded=args.sharded,
+ quantize=args.quantize,
+ speculate=args.speculate,
+ dtype=args.dtype,
+ trust_remote_code=args.trust_remote_code,
+ uds_path=args.uds_path,
+ max_input_tokens=args.max_input_tokens,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_id", type=str)
+ parser.add_argument("--revision", type=str)
+ parser.add_argument("--sharded", type=bool)
+ parser.add_argument("--speculate", type=int, default=None)
+ parser.add_argument("--dtype", type=str)
+ parser.add_argument("--trust_remote_code", type=bool)
+ parser.add_argument("--uds_path", type=Path)
+ parser.add_argument("--quantize", type=str)
+ parser.add_argument("--max_input_tokens", type=int)
+ args = parser.parse_args()
+ main(args)
diff --git a/backends/gaudi/server/text_generation_server/tracing.py b/backends/gaudi/server/text_generation_server/tracing.py
new file mode 100644
index 000000000..bc7a04ee7
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/tracing.py
@@ -0,0 +1,63 @@
+import grpc
+
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.instrumentation.grpc._aio_server import (
+ OpenTelemetryAioServerInterceptor,
+)
+from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import (
+ BatchSpanProcessor,
+)
+
+
+class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
+ def __init__(self):
+ super().__init__(trace.get_tracer(__name__))
+
+ def _start_span(self, handler_call_details, context, set_status_on_exception=False):
+ """
+ Rewrite _start_span method to support Unix Domain Socket gRPC contexts
+ """
+
+ # standard attributes
+ attributes = {
+ SpanAttributes.RPC_SYSTEM: "grpc",
+ SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
+ }
+
+ # if we have details about the call, split into service and method
+ if handler_call_details.method:
+ service, method = handler_call_details.method.lstrip("/").split("/", 1)
+ attributes.update(
+ {
+ SpanAttributes.RPC_METHOD: method,
+ SpanAttributes.RPC_SERVICE: service,
+ }
+ )
+
+ # add some attributes from the metadata
+ metadata = dict(context.invocation_metadata())
+ if "user-agent" in metadata:
+ attributes["rpc.user_agent"] = metadata["user-agent"]
+
+ # We use gRPC over a UNIX socket
+ attributes.update({SpanAttributes.NET_TRANSPORT: "unix"})
+
+ return self._tracer.start_as_current_span(
+ name=handler_call_details.method,
+ kind=trace.SpanKind.SERVER,
+ attributes=attributes,
+ set_status_on_exception=set_status_on_exception,
+ )
+
+
+def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
+ resource = Resource.create(attributes={"service.name": otlp_service_name})
+ span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
+ span_processor = BatchSpanProcessor(span_exporter)
+
+ trace.set_tracer_provider(TracerProvider(resource=resource))
+ trace.get_tracer_provider().add_span_processor(span_processor)
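A minimal usage sketch for `setup_tracing` above, assuming an OTLP collector is reachable at the (hypothetical) endpoint below.

```python
from text_generation_server.tracing import setup_tracing

# Illustrative only: the endpoint is a hypothetical local OTLP collector.
setup_tracing(
    otlp_service_name="text-generation-inference.gaudi",
    otlp_endpoint="localhost:4317",
)
```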
diff --git a/backends/gaudi/server/text_generation_server/utils/__init__.py b/backends/gaudi/server/text_generation_server/utils/__init__.py
new file mode 100644
index 000000000..cda3a4da1
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/__init__.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+from text_generation_server.utils.convert import convert_file, convert_files
+from text_generation_server.utils.dist import initialize_torch_distributed
+from text_generation_server.utils.weights import Weights
+from text_generation_server.utils.peft import download_and_unload_peft
+from text_generation_server.utils.hub import (
+ weight_files,
+ weight_hub_files,
+ download_weights,
+ EntryNotFoundError,
+ LocalEntryNotFoundError,
+ RevisionNotFoundError,
+)
+from text_generation_server.utils.tokens import (
+ NextTokenChooser,
+ HeterogeneousNextTokenChooser,
+ StoppingCriteria,
+ StopSequenceCriteria,
+ FinishReason,
+ Sampling,
+ Greedy,
+ make_tokenizer_optional,
+ is_tokenizer_transparent,
+ pad_next_token_chooser_parameters,
+)
+
+__all__ = [
+ "convert_file",
+ "convert_files",
+ "initialize_torch_distributed",
+ "weight_files",
+ "weight_hub_files",
+ "download_weights",
+ "download_and_unload_peft",
+ "EntryNotFoundError",
+ "HeterogeneousNextTokenChooser",
+ "LocalEntryNotFoundError",
+ "RevisionNotFoundError",
+ "Greedy",
+ "NextTokenChooser",
+ "Sampling",
+ "StoppingCriteria",
+ "StopSequenceCriteria",
+ "FinishReason",
+ "Weights",
+ "make_tokenizer_optional",
+ "is_tokenizer_transparent",
+ "pad_next_token_chooser_parameters",
+]
diff --git a/backends/gaudi/server/text_generation_server/utils/adapter.py b/backends/gaudi/server/text_generation_server/utils/adapter.py
new file mode 100644
index 000000000..2b61f9bb4
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/adapter.py
@@ -0,0 +1,320 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/adapter.py
+# License: Apache License Version 2.0, January 2004
+
+import warnings
+import re
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import TYPE_CHECKING, Set, Tuple, Optional, List
+
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+
+from text_generation_server.utils.merges.strategies import merge_adapters
+
+from text_generation_server.utils import hub
+from text_generation_server.adapters.lora import LoraConfig
+
+
+if TYPE_CHECKING:
+ from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+
+
+BASE_MODEL_ADAPTER_ID = "__base_model__"
+
+
+@dataclass
+class AdapterInfo:
+ id: str
+ path: Optional[str]
+ revision: Optional[str] = None
+
+
+@dataclass
+class AdapterParameters:
+ adapter_info: Tuple[AdapterInfo]
+ weights: Tuple[float]
+ merge_strategy: NotImplemented
+ density: float
+ majority_sign_method: NotImplemented
+
+
+@dataclass
+class AdapterSource:
+ adapter_id: str
+ model_id: str
+ revision: str
+
+
+def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
+ if not lora_adapters:
+ return []
+
+ adapter_list = []
+ for adapter in lora_adapters.split(","):
+ adapter = adapter.strip()
+ if adapter.count("=") > 1 or adapter.count("@") > 1:
+ raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+ match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
+
+ if match:
+ adapter_id, path, revision = match.groups()
+ adapter_list.append(
+ AdapterInfo(id=adapter_id, path=path, revision=revision)
+ )
+ else:
+ raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+ return adapter_list
+
+
+def load_and_merge_adapters(
+ model_id: str,
+ adapter_parameters: AdapterParameters,
+ adapter_index: int,
+ weight_names: Tuple[str],
+ trust_remote_code: bool = False,
+) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
+ if len(adapter_parameters.adapter_info) == 1:
+ adapter = next(iter(adapter_parameters.adapter_info))
+ return load_module_map(
+ model_id,
+ adapter.revision,
+ adapter.id,
+ adapter.path,
+ weight_names,
+ trust_remote_code,
+ )
+
+ adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
+ return _load_and_merge(
+ model_id,
+ adapter_params,
+ weight_names,
+ trust_remote_code,
+ )
+
+
+@dataclass
+class AdapterParametersContainer:
+ adapter_parameters: AdapterParameters
+ adapter_index: int
+
+ def __hash__(self) -> int:
+ return self.adapter_index
+
+
+@lru_cache(maxsize=32)
+def _load_and_merge(
+ model_id: str,
+ adapter_params: AdapterParametersContainer,
+ weight_names: Tuple[str],
+ trust_remote_code: bool = False,
+) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
+ params = adapter_params.adapter_parameters
+
+ adapters_to_merge = []
+ merged_weight_names = set()
+ tokenizer = None
+ for adapter in params.adapter_info:
+ if adapter.id == BASE_MODEL_ADAPTER_ID:
+ raise ValueError("Base model adapter cannot be merged.")
+
+ module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
+ load_module_map(
+ model_id,
+ adapter.revision,
+ adapter.id,
+ adapter.path,
+ weight_names,
+ trust_remote_code,
+ )
+ )
+
+ adapters_to_merge.append((module_map, adapter_config))
+ merged_weight_names = merged_weight_names.union(adapter_weight_names)
+ if tokenizer is None:
+ tokenizer = adapter_tokenizer
+
+ if len(adapters_to_merge) == 0:
+ raise ValueError("No adapters to merge.")
+
+ module_map, adapter_config = merge_adapters(adapters_to_merge, params)
+ return module_map, adapter_config, merged_weight_names, tokenizer
+
+
+def check_architectures(
+ model_id: str,
+ adapter_id: str,
+ adapter_config: "AdapterConfig",
+ trust_remote_code: bool = False,
+):
+ try:
+ if not adapter_config.base_model_name_or_path:
+ # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
+ return
+
+ expected_config = AutoConfig.from_pretrained(
+ model_id, trust_remote_code=trust_remote_code
+ )
+ model_config = AutoConfig.from_pretrained(
+ adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
+ )
+ except Exception as e:
+ warnings.warn(
+ f"Unable to check architecture compatibility for adapter '{adapter_id}' "
+ f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
+ )
+ return
+
+ if model_config.architectures == expected_config.architectures:
+ warnings.warn(
+ f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
+ f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
+ )
+ else:
+ # TODO(travis): revisit this when we support classification heads which will not use CausalLM
+ raise ValueError(
+ f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
+ f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
+ f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
+ )
+
+
+@lru_cache(maxsize=128)
+def load_module_map(
+ model_id: str,
+ revision: str,
+ adapter_id: str,
+ adapter_path: Optional[str],
+ weight_names: Tuple[str],
+ trust_remote_code: bool = False,
+) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
+ adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
+
+ if not adapter_path and adapter_config.base_model_name_or_path != model_id:
+ check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
+
+ adapter_filenames = (
+ hub._weight_files_from_dir(adapter_path, extension=".safetensors")
+ if adapter_path
+ else hub._cached_weight_files(
+ adapter_id, revision=revision, extension=".safetensors"
+ )
+ )
+
+ # throw an error if no adapter weights are found
+ if not adapter_filenames:
+ raise FileNotFoundError(
+ f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
+ )
+
+ try:
+ adapter_tokenizer = AutoTokenizer.from_pretrained(
+ adapter_config.config_path,
+ trust_remote_code=trust_remote_code,
+ )
+ except Exception:
+ # Adapter does not have a tokenizer, so fall back to the base model tokenizer
+ adapter_tokenizer = None
+
+ # load adapter weights from all shards (should have relatively small memory footprint)
+ adapter_weights = {}
+ for filename in adapter_filenames:
+ adapter_weights.update(load_file(filename))
+
+ # map the model weights to the relevant adapter weights (LoRA A and B matrices)
+ module_map, adapter_weight_names = adapter_config.map_weights_for_model(
+ adapter_weights, weight_names
+ )
+ return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
+
+
+def get_attn_weights(i, layer):
+ qkv = layer.self_attn.query_key_value
+ weights = {}
+
+ for k in ["q", "k", "v"]:
+ key = (i, f"{k}_proj")
+ value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
+ weights[key] = value
+
+ # also add the qkv_proj weight for the adapter
+ weights[(i, "qkv_proj")] = (
+ f"model.layers.{i}.self_attn.qkv_proj",
+ qkv,
+ )
+
+ weights[(i, "o_proj")] = (
+ f"model.layers.{i}.self_attn.o_proj",
+ layer.self_attn.o_proj,
+ )
+
+ return weights
+
+
+def get_mlp_weights(i, layer):
+ weights = {}
+ if hasattr(layer, "mlp"):
+ mlp = layer.mlp
+ if hasattr(mlp, "gate_up_proj"):
+ # handle combined gate_up_proj (e.g., for some LLaMA variants)
+ weights.update(
+ {
+ (i, "gate_proj"): (
+ f"model.layers.{i}.mlp.gate_proj",
+ mlp.gate_up_proj,
+ ),
+ (i, "up_proj"): (f"model.layers.{i}.mlp.up_proj", mlp.gate_up_proj),
+ }
+ )
+ else:
+ # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)
+ if hasattr(mlp, "gate_proj"):
+ weights[(i, "gate_proj")] = (
+ f"model.layers.{i}.mlp.gate_proj",
+ mlp.gate_proj,
+ )
+ if hasattr(mlp, "up_proj"):
+ weights[(i, "up_proj")] = (f"model.layers.{i}.mlp.up_proj", mlp.up_proj)
+
+ if hasattr(mlp, "down_proj"):
+ weights[(i, "down_proj")] = (
+ f"model.layers.{i}.mlp.down_proj",
+ mlp.down_proj,
+ )
+
+ return weights
+
+
+# build_layer_weight_lookup creates a mapping of model layers to their corresponding
+# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples
+# containing the weight tensor path and the actual layer object. This mapping is needed
+# for the lora adapter to know which weights to update when applying the adapter.
+def build_layer_weight_lookup(model):
+ if hasattr(model, "language_model"):
+ m = model.language_model.model
+ elif hasattr(model, "text_model"):
+ m = model.text_model.model
+ else:
+ m = model.model
+
+ layer_weights = {}
+
+ for i, layer in enumerate(m.layers):
+ attn_weights = get_attn_weights(i, layer)
+ mlp_weights = get_mlp_weights(i, layer)
+
+ layer_weights.update(attn_weights)
+ layer_weights.update(mlp_weights)
+
+ lm_head = None
+ if hasattr(m, "lm_head"):
+ lm_head = m.lm_head
+ elif hasattr(model, "lm_head"):
+ lm_head = model.lm_head
+
+ if lm_head:
+ layer_weights[(0, "lm_head")] = ("lm_head", lm_head)
+
+ return layer_weights
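Since the `LORA_ADAPTERS` string format (`id`, `id=path`, optional `@revision`) is easy to get wrong, here is a small illustrative sketch of `parse_lora_adapters`; the adapter id and path below are made up.

```python
from text_generation_server.utils.adapter import parse_lora_adapters

# Illustrative only: adapter ids and paths are hypothetical.
adapters = parse_lora_adapters(
    "predibase/customer_support,local=/data/adapters/my_lora@main"
)
for info in adapters:
    # AdapterInfo(id=..., path=..., revision=...)
    print(info.id, info.path, info.revision)
```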
diff --git a/backends/gaudi/server/text_generation_server/utils/chunks.py b/backends/gaudi/server/text_generation_server/utils/chunks.py
new file mode 100644
index 000000000..73962ea39
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/chunks.py
@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+ """
+ Concatenate text in text chunks. Non-text chunks are dropped.
+ """
+ text = None
+ for chunk in chunks:
+ chunk_type = chunk.WhichOneof("chunk")
+ if chunk_type == "text":
+ if text is None:
+ text = chunk.text
+ else:
+ raise NotImplementedError("Request contained more than one text chunk")
+ else:
+ # We cannot reject this, e.g. warmup sends an image chunk.
+ logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+ if text is None:
+ raise NotImplementedError("Request without a text chunk")
+
+ return text
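A quick illustrative sketch of `concat_text_chunks` with a single text chunk; it assumes the generated `generate_pb2` protobufs are importable.

```python
from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

# A request whose input consists of a single text chunk.
chunks = [generate_pb2.InputChunk(text="Hello, Gaudi!")]
assert concat_text_chunks(chunks) == "Hello, Gaudi!"
```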
diff --git a/backends/gaudi/server/text_generation_server/utils/convert.py b/backends/gaudi/server/text_generation_server/utils/convert.py
new file mode 100644
index 000000000..d9c3276bc
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/convert.py
@@ -0,0 +1,114 @@
+import datetime
+import torch
+import os
+
+from loguru import logger
+from pathlib import Path
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict
+from collections import defaultdict
+
+
+def _remove_duplicate_names(
+ state_dict: Dict[str, torch.Tensor],
+ *,
+ preferred_names: List[str] = None,
+ discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
+ if preferred_names is None:
+ preferred_names = []
+ preferred_names = set(preferred_names)
+ if discard_names is None:
+ discard_names = []
+ discard_names = set(discard_names)
+
+ shareds = _find_shared_tensors(state_dict)
+ to_remove = defaultdict(list)
+ for shared in shareds:
+ complete_names = set(
+ [name for name in shared if _is_complete(state_dict[name])]
+ )
+ if not complete_names:
+ if len(shared) == 1:
+ # Force contiguous
+ name = list(shared)[0]
+ state_dict[name] = state_dict[name].clone()
+ complete_names = {name}
+ else:
+ raise RuntimeError(
+ f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+ )
+
+ keep_name = sorted(list(complete_names))[0]
+
+ # Mechanism to preferentially select keys to keep
+ # coming from the on-disk file to allow
+ # loading models saved with a different choice
+ # of keep_name
+ preferred = complete_names.difference(discard_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+
+ if preferred_names:
+ preferred = preferred_names.intersection(complete_names)
+ if preferred:
+ keep_name = sorted(list(preferred))[0]
+ for name in sorted(shared):
+ if name != keep_name:
+ to_remove[keep_name].append(name)
+ return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
+ """
+ Convert a pytorch file to a safetensors file
+ This will remove duplicate tensors from the file.
+
+ Unfortunately, this might not respect the *transformers* convention,
+ forcing us to check for potentially different keys during load when looking
+ for specific tensors (making tensor sharing explicit).
+ """
+ loaded = torch.load(pt_file, map_location="cpu", weights_only=True)
+ if "state_dict" in loaded:
+ loaded = loaded["state_dict"]
+ to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+ metadata = {"format": "pt"}
+ for kept_name, to_remove_group in to_removes.items():
+ for to_remove in to_remove_group:
+ if to_remove not in metadata:
+ metadata[to_remove] = kept_name
+ del loaded[to_remove]
+ # Force tensors to be contiguous
+ loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+ dirname = os.path.dirname(sf_file)
+ os.makedirs(dirname, exist_ok=True)
+ save_file(loaded, sf_file, metadata=metadata)
+ reloaded = load_file(sf_file)
+ for k in loaded:
+ pt_tensor = loaded[k]
+ sf_tensor = reloaded[k]
+ if not torch.equal(pt_tensor, sf_tensor):
+ raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
+ assert len(pt_files) == len(sf_files)
+
+ N = len(pt_files)
+ # We do this instead of using tqdm because we want to parse the logs with the launcher
+
+ for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
+ # Skip blacklisted files
+ if (
+ "arguments" in pt_file.name
+ or "args" in pt_file.name
+ or "training" in pt_file.name
+ ):
+ continue
+
+ start = datetime.datetime.now()
+ convert_file(pt_file, sf_file, discard_names)
+ elapsed = datetime.datetime.now() - start
+ logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
diff --git a/backends/gaudi/server/text_generation_server/utils/debug.py b/backends/gaudi/server/text_generation_server/utils/debug.py
new file mode 100644
index 000000000..8bbcad6a3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/debug.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import os
+import glob
+import time
+
+from optimum.habana.utils import to_gb_rounded
+import habana_frameworks.torch as htorch
+
+START_TS = None
+DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
+if "GRAPH_VISUALIZATION" in os.environ:
+ for f in glob.glob(".graph_dumps/*"):
+ os.remove(f)
+
+
+def count_hpu_graphs():
+ return len(glob.glob(".graph_dumps/*PreGraph*"))
+
+
+def dbg_trace(tag, txt):
+ global START_TS
+ if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
+ if START_TS is None:
+ START_TS = time.perf_counter()
+ time_offset = time.perf_counter() - START_TS
+ mem_stats = htorch.hpu.memory.memory_stats()
+ mem_used = to_gb_rounded(mem_stats["InUse"])
+ max_mem_used = to_gb_rounded(mem_stats["MaxInUse"])
+ print(
+ f"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB "
+ f"mmu:{max_mem_used:.1f}GB | {tag} | {txt}",
+ flush=True,
+ file=open(DBG_TRACE_FILENAME, "a"),
+ )
diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py
new file mode 100644
index 000000000..1c45713e8
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/dist.py
@@ -0,0 +1,66 @@
+import os
+import torch
+from torch.distributed import ProcessGroup
+from datetime import timedelta
+from loguru import logger
+
+# Tensor Parallelism settings
+RANK = int(os.getenv("RANK", "0"))
+WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+
+
+class FakeBarrier:
+ def wait(self):
+ pass
+
+
+class FakeGroup(ProcessGroup):
+ def __init__(self, rank, size):
+ self._rank = rank
+ self._size = size
+ super().__init__(rank, size)
+
+ def allreduce(self, *args, **kwargs):
+ return FakeBarrier()
+
+ def allgather(self, inputs, local_tensor, **kwargs):
+ assert (
+ len(inputs[0]) == len(local_tensor) == 1
+ ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
+ for input_ in inputs:
+ input_[0].data = local_tensor[0].data
+ return FakeBarrier()
+
+ def barrier(self, *args, **kwargs):
+ return FakeBarrier()
+
+ def size(self):
+ return self._size
+
+ def rank(self):
+ return self._rank
+
+ def _get_backend_name(self):
+ return "fake"
+
+
+def initialize_torch_distributed():
+ if WORLD_SIZE == 1:
+ return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
+ else:
+ if os.getenv("DEBUG", None) == "1":
+ return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
+
+ if not torch.distributed.is_initialized():
+ # Call the init process.
+ torch.distributed.init_process_group(
+ backend="hccl",
+ world_size=WORLD_SIZE,
+ rank=RANK,
+ timeout=timedelta(seconds=120),
+ )
+ else:
+ logger.warning("torch.distributed is already initialized.")
+
+ return torch.distributed.group.WORLD, RANK, WORLD_SIZE
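A minimal sketch of `initialize_torch_distributed`; with `WORLD_SIZE` unset (single card) it returns a `FakeGroup` and never initializes HCCL.

```python
from text_generation_server.utils.dist import initialize_torch_distributed

# Single-card run (WORLD_SIZE defaults to 1): no real process group is created.
process_group, rank, world_size = initialize_torch_distributed()
print(process_group.size(), rank, world_size)
```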
diff --git a/backends/gaudi/server/text_generation_server/utils/hub.py b/backends/gaudi/server/text_generation_server/utils/hub.py
new file mode 100644
index 000000000..f9c476ac3
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/hub.py
@@ -0,0 +1,234 @@
+import time
+import os
+
+from datetime import timedelta
+from loguru import logger
+from pathlib import Path
+from typing import Optional, List
+
+from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub.utils import (
+ LocalEntryNotFoundError,
+ EntryNotFoundError,
+ RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
+)
+
+WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
+HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
+
+
+def _cached_weight_files(
+ model_id: str, revision: Optional[str], extension: str
+) -> List[str]:
+ """Guess weight files from the cached revision snapshot directory"""
+ d = _get_cached_revision_directory(model_id, revision)
+ if not d:
+ return []
+ filenames = _weight_files_from_dir(d, extension)
+ return filenames
+
+
+def _weight_hub_files_from_model_info(
+ info: hf_api.ModelInfo, extension: str
+) -> List[str]:
+ return [
+ s.rfilename
+ for s in info.siblings
+ if s.rfilename.endswith(extension)
+ and len(s.rfilename.split("/")) == 1
+ and "arguments" not in s.rfilename
+ and "args" not in s.rfilename
+ and "training" not in s.rfilename
+ ]
+
+
+def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
+ # os.walk: only scan the top level of the directory (depth 1), not recursively;
+ # this mirrors the len(s.rfilename.split("/")) == 1 condition used in
+ # _weight_hub_files_from_model_info.
+ root, _, files = next(os.walk(str(d)))
+ filenames = [
+ os.path.join(root, f)
+ for f in files
+ if f.endswith(extension)
+ and "arguments" not in f
+ and "args" not in f
+ and "training" not in f
+ ]
+ return filenames
+
+
+def _get_cached_revision_directory(
+ model_id: str, revision: Optional[str]
+) -> Optional[Path]:
+ if revision is None:
+ revision = "main"
+
+ repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
+ file_download.repo_folder_name(repo_id=model_id, repo_type="model")
+ )
+
+ if not repo_cache.is_dir():
+ # No cache for this model
+ return None
+
+ refs_dir = repo_cache / "refs"
+ snapshots_dir = repo_cache / "snapshots"
+
+ # Resolve refs (for instance to convert main to the associated commit sha)
+ if refs_dir.is_dir():
+ revision_file = refs_dir / revision
+ if revision_file.exists():
+ with revision_file.open() as f:
+ revision = f.read()
+
+ # Check if revision folder exists
+ if not snapshots_dir.exists():
+ return None
+ cached_shas = os.listdir(snapshots_dir)
+ if revision not in cached_shas:
+ # No cache for this revision and we won't try to return a random revision
+ return None
+
+ return snapshots_dir / revision
+
+
+def weight_hub_files(
+ model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[str]:
+ """Get the weights filenames on the hub"""
+ api = HfApi()
+
+ if HF_HUB_OFFLINE:
+ filenames = _cached_weight_files(model_id, revision, extension)
+ else:
+ # Online case, fetch model info from the Hub
+ info = api.model_info(model_id, revision=revision)
+ filenames = _weight_hub_files_from_model_info(info, extension)
+
+ if not filenames:
+ raise EntryNotFoundError(
+ f"No {extension} weights found for model {model_id} and revision {revision}.",
+ None,
+ )
+
+ return filenames
+
+
+def try_to_load_from_cache(
+ model_id: str, revision: Optional[str], filename: str
+) -> Optional[Path]:
+ """Try to load a file from the Hugging Face cache"""
+
+ d = _get_cached_revision_directory(model_id, revision)
+ if not d:
+ return None
+
+ # Check if file exists in cache
+ cached_file = d / filename
+ return cached_file if cached_file.is_file() else None
+
+
+def weight_files(
+ model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
+) -> List[Path]:
+ """Get the local files"""
+ # Local model
+ d = Path(model_id)
+ if d.exists() and d.is_dir():
+ local_files = _weight_files_from_dir(d, extension)
+ if not local_files:
+ raise FileNotFoundError(
+ f"No local weights found in {model_id} with extension {extension}"
+ )
+ return [Path(f) for f in local_files]
+
+ try:
+ filenames = weight_hub_files(model_id, revision, extension)
+ except EntryNotFoundError as e:
+ if extension != ".safetensors":
+ raise e
+ # Try to see if there are pytorch weights
+ pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
+ # Change pytorch extension to safetensors extension
+ # It is possible that we have safetensors weights locally even though they are not on the
+ # hub if we converted weights locally without pushing them
+ filenames = [
+ f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
+ ]
+
+ if WEIGHTS_CACHE_OVERRIDE is not None:
+ files = []
+ for filename in filenames:
+ p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
+ if not p.exists():
+ raise FileNotFoundError(
+ f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
+ )
+ files.append(p)
+ return files
+
+ files = []
+ for filename in filenames:
+ cache_file = try_to_load_from_cache(
+ model_id, revision=revision, filename=filename
+ )
+ if cache_file is None:
+ raise LocalEntryNotFoundError(
+ f"File {filename} of model {model_id} not found in "
+ f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+ f"Please run `text-generation-server download-weights {model_id}` first."
+ )
+ files.append(cache_file)
+
+ return files
+
+
+def download_weights(
+ filenames: List[str], model_id: str, revision: Optional[str] = None
+) -> List[Path]:
+ """Download the safetensors files from the hub"""
+
+ def download_file(fname, tries=5, backoff: int = 5):
+ local_file = try_to_load_from_cache(model_id, revision, fname)
+ if local_file is not None:
+ logger.info(f"File {fname} already present in cache.")
+ return Path(local_file)
+
+ for idx in range(tries):
+ try:
+ logger.info(f"Download file: {fname}")
+ stime = time.time()
+ local_file = hf_hub_download(
+ filename=fname,
+ repo_id=model_id,
+ revision=revision,
+ local_files_only=HF_HUB_OFFLINE,
+ )
+ logger.info(
+ f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
+ )
+ return Path(local_file)
+ except Exception as e:
+ if idx + 1 == tries:
+ raise e
+ logger.error(e)
+ logger.info(f"Retrying in {backoff} seconds")
+ time.sleep(backoff)
+ logger.info(f"Retry {idx + 1}/{tries - 1}")
+
+ # We do this instead of using tqdm because we want to parse the logs with the launcher
+ start_time = time.time()
+ files = []
+ for i, filename in enumerate(filenames):
+ file = download_file(filename)
+
+ elapsed = timedelta(seconds=int(time.time() - start_time))
+ remaining = len(filenames) - (i + 1)
+ eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
+
+ logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
+ files.append(file)
+
+ return files
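A short sketch of the typical hub flow (list, download, resolve). The model id is just an example; network access is assumed unless `HF_HUB_OFFLINE` is set.

```python
from text_generation_server.utils.hub import (
    download_weights,
    weight_files,
    weight_hub_files,
)

model_id = "gpt2"  # hypothetical example model
remote_filenames = weight_hub_files(model_id)   # *.safetensors names on the Hub
download_weights(remote_filenames, model_id)    # populate the local cache
local_paths = weight_files(model_id)            # resolved cached Path objects
```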
diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
new file mode 100644
index 000000000..22560dd7a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -0,0 +1,28 @@
+import torch
+from loguru import logger
+
+
+def get_hpu_free_memory(device, memory_fraction):
+ from habana_frameworks.torch.hpu import memory_stats
+
+ device_id = device.index
+ mem_stats = memory_stats(device_id)
+ logger.info(f"mem_stats: {mem_stats}")
+ total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
+ free_memory = max(
+ 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
+ )
+ return free_memory
+
+
+def synchronize_hpu(device):
+ torch.hpu.synchronize()
+
+
+def noop(*args, **kwargs):
+ pass
+
+
+empty_cache = noop
+synchronize = synchronize_hpu
+get_free_memory = get_hpu_free_memory
diff --git a/backends/gaudi/server/text_generation_server/utils/kernels.py b/backends/gaudi/server/text_generation_server/utils/kernels.py
new file mode 100644
index 000000000..42745c716
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/kernels.py
@@ -0,0 +1,22 @@
+import importlib
+
+from loguru import logger
+from hf_kernels import load_kernel as hf_load_kernel
+
+from text_generation_server.utils.log import log_once
+
+
+def load_kernel(*, module: str, repo_id: str):
+ """
+ Load a kernel. First try to load it as the given module (e.g. for
+ local development), falling back to a locked Hub kernel.
+ """
+ try:
+ m = importlib.import_module(module)
+ log_once(logger.info, f"Using local module for `{module}`")
+ return m
+ except ModuleNotFoundError:
+ return hf_load_kernel(repo_id=repo_id)
+
+
+__all__ = ["load_kernel"]
diff --git a/backends/gaudi/server/text_generation_server/utils/log.py b/backends/gaudi/server/text_generation_server/utils/log.py
new file mode 100644
index 000000000..4385c71ee
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/log.py
@@ -0,0 +1,15 @@
+from functools import lru_cache
+from text_generation_server.utils.dist import RANK
+
+
+@lru_cache(10)
+def log_once(log, msg: str, master=True):
+ if master:
+ log_master(log, msg)
+ else:
+ log(msg)
+
+
+def log_master(log, msg: str):
+ if RANK == 0:
+ log(msg)
diff --git a/backends/gaudi/server/text_generation_server/utils/logits_process.py b/backends/gaudi/server/text_generation_server/utils/logits_process.py
new file mode 100644
index 000000000..c0fd6cbae
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/logits_process.py
@@ -0,0 +1,610 @@
+import math
+import torch
+import habana_frameworks.torch.core as htcore
+
+from loguru import logger
+from typing import Dict
+from text_generation_server.pb.generate_pb2 import GrammarType
+
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_schema
+from functools import lru_cache
+from typing import List, Optional, DefaultDict
+import time
+
+from transformers import (
+ LogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ TypicalLogitsWarper,
+)
+
+mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+
+
+class StaticWarper:
+ def __init__(
+ self,
+ temperature=1.0,
+ top_k=None,
+ top_p=None,
+ typical_p=None,
+ ):
+ self.warpers = []
+
+ if temperature is not None and temperature != 1.0:
+ temperature = float(temperature)
+ self.warpers.append(TemperatureLogitsWarper(temperature))
+ if top_k is not None and top_k != 0:
+ self.warpers.append(TopKLogitsWarper(top_k=top_k))
+ if top_p is not None and top_p < 1.0:
+ self.warpers.append(TopPLogitsWarper(top_p=top_p))
+ if typical_p is not None and typical_p < 1.0:
+ self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+
+ self.hpu_graph = None
+ self.static_scores = None
+ self.static_warped_scores = None
+ self.static_next_logprob = None
+
+ def __call__(self, scores):
+ if self.hpu_graph is None:
+ self.static_scores = scores.clone().contiguous()
+ self.static_warped_scores = scores.clone().contiguous()
+ self.static_next_logprob = scores.clone().contiguous()
+ self.hpu_graph = htcore.hpu.HPUGraph()
+
+ with htcore.hpu.graph(self.hpu_graph):
+ local_scores = self.static_scores
+ for warper in self.warpers:
+ local_scores = warper(None, local_scores)
+
+ self.static_warped_scores.copy_(local_scores)
+ # Compute logprobs
+ self.static_next_logprob.copy_(
+ torch.log_softmax(self.static_warped_scores, -1)
+ )
+
+ self.static_scores.copy_(scores)
+ self.hpu_graph.replay()
+
+ return self.static_warped_scores, self.static_next_logprob
+
+
+@lru_cache(10)
+def static_warper(
+ temperature: Optional[float],
+ top_k: Optional[int],
+ top_p: Optional[float],
+ typical_p: Optional[float],
+) -> StaticWarper:
+ return StaticWarper(
+ temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+ )
+
+
+class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ repetition_penalty (`List[float]`):
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ """
+
+ def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+ self.penalty = penalty
+ self.penalty_tensor = torch.tensor(
+ penalty, dtype=dtype, device=device
+ ).unsqueeze(1)
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ score = torch.gather(scores, 1, input_ids)
+
+ # if score < 0, multiply by the repetition penalty to further reduce the previous token's probability
+ score = torch.where(
+ score < 0, score * self.penalty_tensor, score / self.penalty_tensor
+ )
+
+ scores.scatter_(1, input_ids, score)
+ return scores
+
+ def filter(self, indices):
+ self.penalty = [self.penalty[i] for i in indices]
+ if any([x != 1.0 for x in self.penalty]):
+ self.penalty_tensor = self.penalty_tensor[indices]
+ return self
+ return None
+
+
+class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ Frequency penalty as defined by OpenAI
+
+ Args:
+ penalty (`float`):
+ The parameter for frequency penalty. 0.0 means no penalty.
+ """
+
+ def __init__(self, penalty: float):
+ self.penalty = penalty
+
+ def __call__(
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ score = torch.gather(scores, 1, input_ids)
+ # if score < 0, multiply by the penalty to further reduce the previous token's probability
+ score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+ # set score to 0 where input_ids is a padding token
+ score *= input_ids.ne(0)
+
+ return scores.scatter_add_(1, input_ids, score)
+
+
+class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ Frequency penalty as defined by OpenAI in
+ https://platform.openai.com/docs/guides/text-generation/parameter-details
+
+ Args:
+ frequency_penalty (`List[float]`):
+ The parameter for frequency penalty. 0.0 means no penalty.
+ """
+
+ def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+ self.penalty = penalty
+ self.penalty_tensor = torch.tensor(
+ penalty, dtype=dtype, device=device
+ ).unsqueeze(1)
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ batch_size, input_size = input_ids.size()
+ vocab_size = scores.size(1)
+
+ # Calculate the frequency for each token so far
+ token_freq = torch.zeros(
+ batch_size, vocab_size, dtype=scores.dtype, device=scores.device
+ )
+ token_freq.scatter_add_(
+ 1,
+ input_ids,
+ torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device),
+ )
+ token_freq /= input_size
+
+ # Apply the frequency penalty to logits
+ scores -= token_freq * self.penalty_tensor
+ return scores
+
+ def filter(self, indices):
+ self.penalty = [self.penalty[i] for i in indices]
+ if any([x != 0.0 for x in self.penalty]):
+ self.penalty_tensor = self.penalty_tensor[indices]
+ return self
+ return None
+
+
+class HeterogeneousTemperatureLogitsWarper:
+ r"""
+ [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ temperature (`List[float]`):
+ The per-sample value used to modulate the logits distribution.
+ """
+
+ def __init__(
+ self, temperature: List[float], dtype: torch.dtype, device: torch.device
+ ):
+ self.temperature = temperature
+ self.temperature_tensor = torch.tensor(
+ temperature, dtype=dtype, device=device
+ ).unsqueeze(1)
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ scores.div_(self.temperature_tensor)
+ return scores
+
+ def filter(self, indices):
+ self.temperature = [self.temperature[i] for i in indices]
+ if any([x != 1.0 for x in self.temperature]):
+ self.temperature_tensor = self.temperature_tensor[indices]
+ return self
+ return None
+
+
+class HeterogeneousTopPLogitsWarper(LogitsProcessor):
+ """
+ [`LogitsWarper`] that performs top-p filtering, i.e. restricting to the smallest set of tokens whose cumulative probability reaches `top_p`.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ top_p (`List[float]`):
+ Per-sample value; if set to < 1, only the smallest set of most probable tokens with probabilities that
+ add up to `top_p` or higher are kept for generation.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(
+ self,
+ top_p: List[float],
+ dtype: torch.dtype,
+ device: torch.device,
+ filter_value: float = -math.inf,
+ min_tokens_to_keep: int = 1,
+ ):
+ self.top_p = top_p
+ self.top_p_opposite = 1 - torch.tensor(
+ top_p, dtype=dtype, device=device
+ ).unsqueeze(1)
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+ probs = sorted_logits.softmax(dim=-1)
+ # This is way faster for some reason
+ for i in range(probs.shape[0]):
+ probs[i] = probs[i].cumsum(dim=-1)
+
+ # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
+ sorted_indices_to_remove = probs <= self.top_p_opposite
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+ warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
+
+ return warped_scores
+
+ def filter(self, indices):
+ self.top_p = [self.top_p[i] for i in indices]
+ if any([x < 1.0 for x in self.top_p]):
+ self.top_p_opposite = self.top_p_opposite[indices]
+ return self
+ return None
+
+
+class HeterogeneousTopKLogitsWarper(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ top_k (`List[int]`):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(
+ self,
+ top_k: List[int],
+ device: torch.device,
+ filter_value: float = -math.inf,
+ min_tokens_to_keep: int = 1,
+ ):
+ self.top_k = top_k
+ self.max_top_k = max(top_k)
+ # value - 1 as we will use top_k to index and python uses 0 based numbering
+ self.top_k_tensor = torch.tensor(
+ [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
+ dtype=torch.int64,
+ device=device,
+ ).unsqueeze(1)
+
+ # 0 is a special value that disables top_k warping for this member of the batch
+ disabled = [x == 0 for x in top_k]
+
+ if any(disabled):
+ self.top_k_disabled_mask = torch.tensor(
+ disabled, dtype=torch.bool, device=device
+ ).view(-1, 1)
+ else:
+ self.top_k_disabled_mask = None
+
+ self.filter_value = filter_value
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ # If max_top_k is larger than the vocab size, we need to clamp it or the warper will fail
+ if scores.size(-1) < self.max_top_k:
+ max_top_k = scores.size(-1)
+ top_k = torch.clamp_max(self.top_k_tensor, max_top_k)
+ else:
+ max_top_k = self.max_top_k
+ top_k = self.top_k_tensor
+
+ # Get the kth score for each member of the batch
+ kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+
+ # Mask member of kth_scores that do not want to use top_k warping
+ if self.top_k_disabled_mask is not None:
+ kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)
+
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = scores < kth_scores
+ scores.masked_fill_(indices_to_remove, self.filter_value)
+ return scores
+
+ def filter(self, indices):
+ self.top_k = [self.top_k[i] for i in indices]
+ disabled = [x == 0 for x in self.top_k]
+
+ if not all(disabled):
+ self.top_k_tensor = self.top_k_tensor[indices]
+ self.max_top_k = max(self.top_k)
+
+ if self.top_k_disabled_mask is not None:
+ self.top_k_disabled_mask = (
+ self.top_k_disabled_mask[indices] if any(disabled) else None
+ )
+
+ return self
+ return None
+
+
+class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
+ Generation](https://arxiv.org/abs/2202.00666) for more information.
+ This version allows for a separate value for each sample and runs inplace when possible.
+ It doesn't validate inputs.
+
+ Args:
+ mass (`List[float]`):
+ Per-sample value of typical_p, between 0 and 1 inclusive.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(
+ self,
+ mass: List[float],
+ dtype: torch.dtype,
+ device: torch.device,
+ filter_value: float = -math.inf,
+ min_tokens_to_keep: int = 1,
+ ):
+ self.mass = mass
+ self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)
+
+ # 1 is a special value that disables typical_p warping for this member of the batch
+ disabled = [x == 1.0 for x in mass]
+
+ if any(disabled):
+ self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
+ else:
+ self.disabled_mask = None
+
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ # calculate entropy
+ normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+ p = torch.exp(normalized)
+ ent = -(normalized * p).nansum(-1, keepdim=True)
+
+ # shift and sort
+ shifted_scores = torch.abs((-normalized) - ent)
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+ sorted_logits = scores.gather(-1, sorted_indices)
+ probs = sorted_logits.softmax(dim=-1)
+ # This is way faster for some reason
+ for i in range(probs.shape[0]):
+ probs[i] = probs[i].cumsum(dim=-1)
+
+ # Remove tokens with cumulative mass above the threshold
+ last_ind = (probs < self.mass_tensor).sum(dim=1)
+ last_ind[last_ind < 0] = 0
+
+ if self.disabled_mask is not None:
+ last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
+
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
+ 1, last_ind.view(-1, 1)
+ )
+ if self.min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+
+ warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
+
+ return warped_scores
+
+ def filter(self, indices):
+ self.mass = [self.mass[i] for i in indices]
+ disabled = [x == 1.0 for x in self.mass]
+
+ if not all(disabled):
+ self.mass_tensor = self.mass_tensor[indices]
+
+ if self.disabled_mask is not None:
+ self.disabled_mask = (
+ self.disabled_mask[indices] if any(disabled) else None
+ )
+
+ return self
+ return None
+
+
+class HeterogeneousProcessorWrapper(LogitsProcessor):
+ r"""
+ A wrapper for logit warpers or processors without heterogeneous parameter support.
+ Args:
+ processors (`Dict[int, LogitsProcessor]`):
+ A mapping of sample indices to logit warpers or processors, to be run sequentially.
+ """
+
+ def __init__(
+ self,
+ processors: Dict[int, LogitsProcessor],
+ ):
+ self.processors = processors
+
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+ for i, processor in self.processors.items():
+ scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
+ return scores
+
+ def filter(self, indices):
+ new_processors = {}
+ for i, idx in enumerate(indices):
+ if idx in self.processors:
+ new_processors[i] = self.processors[idx]
+
+ if new_processors:
+ self.processors = new_processors
+ return self
+ return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+ fsm_state: DefaultDict[int, int]
+ fsm: RegexFSM
+
+ def __init__(self, tokenizer, device, grammar, grammar_type):
+ self.device = device
+ self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+ self.fsm = GrammarLogitProcessor._cached_compile_fsm(
+ grammar_type, grammar, self.tokenizer
+ )
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ fsm_grammar_state: int,
+ ):
+ if fsm_grammar_state == -1 or self.fsm is None:
+ return logits
+ allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
+ mask = torch.full_like(logits, -math.inf)
+ mask[:, allowed_tokens] = 0
+ biased_scores = logits + mask
+ return biased_scores
+
+ def advance(self, next_token_id, fsm_grammar_state):
+ return GrammarLogitProcessor._advance(
+ next_token_id, fsm_grammar_state, self.fsm
+ )
+
+ @staticmethod
+ def _advance(next_token_id, fsm_grammar_state, fsm):
+ if fsm_grammar_state == -1:
+ return fsm_grammar_state
+ return fsm.next_state(fsm_grammar_state, next_token_id)
+
+ # TODO: move grammar compilation into the router
+ @staticmethod
+ @lru_cache(maxsize=32, typed=True)
+ def _cached_compile_fsm(grammar_type, schema, tokenizer):
+ start_time = time.time()
+ if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
+ schema = build_regex_from_schema(schema)
+ elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
+ pass # schema is already a regex just here for clarity
+ fsm = RegexFSM(schema, tokenizer)
+ logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
+ return fsm
+
+ @staticmethod
+ @lru_cache(maxsize=32, typed=True)
+ def _cached_adapt_tokenizer(tokenizer):
+ """Adapt tokenizer to work with the FSM.
+
+ The API of Outlines tokenizers is slightly different from that of
+ `transformers`. We also need to handle the leading spaces that Llama's
+ tokenizer drops from tokens so that FSMs can be compiled for this model.
+
+ """
+ start_time = time.time()
+ tokenizer.vocabulary = tokenizer.get_vocab()
+ tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+ def convert_token_to_string(token: str) -> str:
+ from transformers.file_utils import SPIECE_UNDERLINE
+
+ string = tokenizer.convert_tokens_to_string([token])
+
+ # A hack to handle missing spaces to HF's Llama tokenizers
+ if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+ return " " + string
+
+ return string
+
+ tokenizer.convert_token_to_string = convert_token_to_string
+ logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
+ return tokenizer
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+ def __init__(self, tokenizer, device, grammars, grammar_types):
+ self.device = device
+ self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+ self.fsms = []
+ for grammar, grammar_type in zip(grammars, grammar_types):
+ if len(grammar) == 0:
+ self.fsms.append(None)
+ continue
+ fsm = GrammarLogitProcessor._cached_compile_fsm(
+ grammar_type, grammar, self.tokenizer
+ )
+ self.fsms.append(fsm)
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ fsm_grammar_states: List[int],
+ ):
+ mask = torch.full_like(logits, -math.inf)
+ for i in range(logits.shape[0]):
+ fsm = self.fsms[i]
+ if fsm is None:
+ continue
+ allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+ mask[i, allowed_tokens] = 0
+ logits[i] += mask[i]
+ return logits
+
+ def advance_batch(self, next_token_ids, fsm_grammar_states):
+ return [
+ GrammarLogitProcessor._advance(
+ next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+ )
+ for i in range(len(next_token_ids))
+ ]
+
+ def advance_at_index(self, next_token_id, fsm_grammar_state, index):
+ if self.fsms[index] is None:
+ return fsm_grammar_state
+ return GrammarLogitProcessor._advance(
+ next_token_id, fsm_grammar_state, self.fsms[index]
+ )
+
+ def filter(self, indices):
+ new_fsms = []
+ for i in indices:
+ new_fsms.append(self.fsms[i])
+ self.fsms = new_fsms
+ return self
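As a small illustration of the heterogeneous processors above, the sketch below applies a per-request repetition penalty to a dummy batch and then filters the batch down to one request. It assumes this module's imports (habana_frameworks, outlines) are available in the environment.

```python
import torch

from text_generation_server.utils.logits_process import (
    HeterogeneousRepetitionPenaltyLogitsProcessor,
)

# A batch of two requests with different repetition penalties.
input_ids = torch.tensor([[1, 2], [3, 4]])
scores = torch.randn(2, 8)
processor = HeterogeneousRepetitionPenaltyLogitsProcessor(
    penalty=[1.2, 1.0], dtype=scores.dtype, device=scores.device
)
scores = processor(input_ids, scores)

# Once the second request finishes, keep only the first request's parameters.
processor = processor.filter([0])
```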
diff --git a/backends/gaudi/server/text_generation_server/utils/merges/strategies.py b/backends/gaudi/server/text_generation_server/utils/merges/strategies.py
new file mode 100644
index 000000000..cb39cde1f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/merges/strategies.py
@@ -0,0 +1,220 @@
+import copy
+from abc import ABC
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
+from text_generation_server.utils.merges.utils import (
+ calculate_majority_sign_mask,
+ disjoint_merge,
+ prune,
+)
+import torch
+
+if TYPE_CHECKING:
+ from text_generation_server.adapters.lora import LoraConfig
+ from text_generation_server.utils.adapter import ModuleMap
+
+
+class AdapterParameters:
+ def __init__(
+ self, adapter_ids, weights, merge_strategy, density, majority_sign_method
+ ):
+ self.adapter_ids = adapter_ids
+ self.weights = weights
+ self.merge_strategy = merge_strategy
+ self.density = density
+ self.majority_sign_method = majority_sign_method
+
+
+def _apply_weights(
+ tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
+) -> torch.Tensor:
+ if isinstance(tensors, torch.Tensor):
+ t = tensors
+ else:
+ t = torch.stack(tensors, dim=0)
+
+ # element-wise weighting of each task tensor
+ # need to unsqueeze weights to match task tensor dimensions
+ # for multiplication to apply element-wise
+ while len(t.shape) > len(w.shape):
+ w = w.unsqueeze(-1)
+ return t * w
+
+
+class MergeStrategy(ABC):
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ raise NotImplementedError()
+
+
+class LinearMerge(MergeStrategy):
+ def __init__(self, **kwargs):
+ pass
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+ return weighted_task_tensors.sum(dim=0)
+
+
+class TiesMerge(MergeStrategy):
+ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
+ self.density = density
+ self.majority_sign_method = majority_sign_method
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ # sparsify
+ task_tensors = [
+ prune(tensor, self.density, method="magnitude") for tensor in task_tensors
+ ]
+ task_tensors = torch.stack(task_tensors, dim=0)
+
+ # elect sign before applying weights
+ majority_sign_mask = calculate_majority_sign_mask(
+ task_tensors, method=self.majority_sign_method
+ )
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+
+ # disjoint merge
+ return disjoint_merge(weighted_task_tensors, majority_sign_mask)
+
+
+class DareLinearMerge(MergeStrategy):
+ def __init__(self, density: float, **kwargs):
+ self.density = density
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ # sparsify
+ task_tensors = [
+ prune(tensor, self.density, method="random", rescale=True)
+ for tensor in task_tensors
+ ]
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+ return weighted_task_tensors.sum(dim=0)
+
+
+class DareTiesMerge(MergeStrategy):
+ def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
+ self.density = density
+ self.majority_sign_method = majority_sign_method
+
+ def merge(
+ self, task_tensors: List[torch.Tensor], weights: torch.Tensor
+ ) -> torch.Tensor:
+ # sparsify
+ task_tensors = [
+ prune(tensor, self.density, method="random", rescale=True)
+ for tensor in task_tensors
+ ]
+ task_tensors = torch.stack(task_tensors, dim=0)
+
+ # elect sign before applying weights
+ majority_sign_mask = calculate_majority_sign_mask(
+ task_tensors, method=self.majority_sign_method
+ )
+ weighted_task_tensors = _apply_weights(task_tensors, weights)
+
+ # disjoint merge
+ mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
+ return mixed_task_tensors
+
+
+strategy_registry: Dict[str, Type[MergeStrategy]] = {
+ "linear": LinearMerge,
+ "ties": TiesMerge,
+ "dare_linear": DareLinearMerge,
+ "dare_ties": DareTiesMerge,
+}
+
+
+def merge_adapters(
+ adapters: List[Tuple["ModuleMap", "LoraConfig"]],
+ merge_params: AdapterParameters,
+) -> Tuple["ModuleMap", "LoraConfig"]:
+ # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
+ strategy_name = "linear"
+
+ weights = merge_params.weights
+ if not weights:
+ weights = torch.ones(len(adapters))
+ else:
+ weights = torch.tensor(weights)
+
+ merge_config = {
+ "density": merge_params.density,
+ # "majority_sign_method": MajoritySignMethodEnum.Name(
+ # merge_params.majority_sign_method
+ # ).lower(),
+ "majority_sign_method": "total",
+ }
+ merge_strategy = strategy_registry[strategy_name](**merge_config)
+
+ module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
+ lambda: defaultdict(lambda: defaultdict(list))
+ )
+ lora_configs = []
+ weight_name_to_adapter_idx = defaultdict(list)
+
+ # input is list of (module_map, lora_config) tuples
+ # convert into dict[k][param_name] -> list of tensors
+ for idx, (module_map, lora_config) in enumerate(adapters):
+ for weight_name, data in module_map.items():
+ weight_name_to_adapter_idx[weight_name].append(idx)
+ for k, (param_data, param_name) in data.items():
+ module_maps[weight_name][k][param_name].append(param_data)
+ lora_configs.append(lora_config)
+
+ # validate lora configs are compatible
+ _validate_lora_configs(lora_configs)
+
+ # merge tensors for each module such that we have a single ModuleMap:
+ # dict[k] -> merged tensor
+ merged_module_map: "ModuleMap" = defaultdict(dict)
+ for weight_name, data in module_maps.items():
+ indices = weight_name_to_adapter_idx[weight_name]
+ param_weights = weights[indices]
+ for k, param_data in data.items():
+ for param_name, tensors in param_data.items():
+ merged_tensor = merge_strategy.merge(tensors, param_weights)
+ merged_module_map[weight_name][k] = (merged_tensor, param_name)
+
+ # merge lora configs
+ merged_lora_config = _merge_lora_configs(lora_configs)
+
+ return merged_module_map, merged_lora_config
+
+
+def _validate_lora_configs(lora_configs: List["LoraConfig"]):
+ # check that all configs have the same rank
+ ranks = set(lora_config.r for lora_config in lora_configs)
+ if len(ranks) > 1:
+ raise ValueError(
+ f"unable to merge adapters, lora configs have different ranks: {ranks}"
+ )
+
+ if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
+ raise ValueError(
+ "unable to merge adapters, lora configs have no target modules"
+ )
+
+
+def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
+ merged_lora_config = copy.copy(lora_configs[0])
+
+ # merge target modules as a union operation
+ merged_target_modules = sorted(
+ set(
+ module
+ for lora_config in lora_configs
+ for module in lora_config.target_modules
+ )
+ )
+ merged_lora_config.target_modules = merged_target_modules
+
+ return merged_lora_config
diff --git a/backends/gaudi/server/text_generation_server/utils/merges/utils.py b/backends/gaudi/server/text_generation_server/utils/merges/utils.py
new file mode 100644
index 000000000..d9ad3278a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/merges/utils.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# From: https://github.com/huggingface/peft/pull/1364
+# Copyright 2024-present the HuggingFace Inc. team.
+# Modifications by Predibase, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+import torch
+
+
+def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
+ """
+ Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
+ `density`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to prune.
+ density (`float`): The fraction of values to preserve. Should be in [0, 1].
+ """
+ mask = torch.zeros_like(tensor).reshape(-1)
+ k = int(density * tensor.reshape(-1).shape[0])
+ top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
+ mask[top_k[1]] = 1
+ return tensor * mask.reshape(tensor.shape)
+
+
+def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
+ """
+ Prune the task tensor at random, keeping each value with probability `density`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to prune.
+ density (`float`): The fraction of values to preserve. Should be in [0, 1].
+ rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
+ """
+ mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
+ pruned_tensor = tensor * mask
+ if rescale:
+ pruned_tensor = torch.div(input=pruned_tensor, other=density)
+ return pruned_tensor
+
+
+def prune(
+ tensor: torch.Tensor,
+ density: float,
+ method: Literal["magnitude", "random"],
+ rescale: bool = False,
+) -> torch.Tensor:
+ """
+ Prune the values of task tensors based on the `method`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to prune.
+ density (`float`): The fraction of values to preserve. Should be in [0, 1].
+ method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
+ rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
+ """
+ if density >= 1:
+ return tensor
+ elif density < 0:
+ raise ValueError("Density should be >= 0, got {density}")
+ if method == "magnitude":
+ return magnitude_based_pruning(tensor, density)
+ elif method == "random":
+ return random_pruning(tensor, density, rescale=rescale)
+ else:
+ raise ValueError(f"Unknown method {method}")
+
+
+def calculate_majority_sign_mask(
+ tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
+):
+ """
+ Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to get the mask from.
+ method (`str`): The method to use to get the mask. Should be one of ["total", "frequency"].
+ """
+
+ sign = tensor.sign()
+ if method == "total":
+ sign_magnitude = (sign * tensor.abs()).sum(dim=0)
+ elif method == "frequency":
+ sign_magnitude = sign.sum(dim=0)
+ else:
+ raise RuntimeError(f'Unimplemented mask method "{method}"')
+ majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
+ return sign == majority_sign
+
+
+def disjoint_merge(task_tensors, majority_sign_mask):
+ mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
+ num_params_preserved = majority_sign_mask.sum(dim=0)
+ return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
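+
+# Illustrative sketch (assumed values, not from the upstream source): for two
+# stacked task tensors t = torch.stack([torch.tensor([1.0, -2.0]),
+# torch.tensor([3.0, 1.0])]), calculate_majority_sign_mask(t, "total") yields
+# sign magnitudes [4.0, -1.0], hence majority signs [+, -] and the mask
+# [[True, True], [True, False]]; disjoint_merge(t, mask) then averages only the
+# sign-agreeing entries, giving tensor([2.0, -2.0]).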
diff --git a/backends/gaudi/server/text_generation_server/utils/peft.py b/backends/gaudi/server/text_generation_server/utils/peft.py
new file mode 100644
index 000000000..d49e73f00
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/peft.py
@@ -0,0 +1,68 @@
+import os
+from typing import Union
+from loguru import logger
+import torch
+
+from transformers import AutoTokenizer
+from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
+
+
+def download_and_unload_peft(model_id, revision, trust_remote_code):
+ torch_dtype = torch.float16
+
+ logger.info("Trying to load a Peft model. It might take a while without feedback")
+ try:
+ model = AutoPeftModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ except Exception:
+ model = AutoPeftModelForSeq2SeqLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ logger.info("Peft model detected.")
+ logger.info("Merging the lora weights.")
+
+ base_model_id = model.peft_config["default"].base_model_name_or_path
+
+ model = model.merge_and_unload()
+
+ os.makedirs(model_id, exist_ok=True)
+ cache_dir = model_id
+ logger.info(f"Saving the newly created merged model to {cache_dir}")
+ tokenizer = AutoTokenizer.from_pretrained(
+ base_model_id, trust_remote_code=trust_remote_code
+ )
+ model.save_pretrained(cache_dir, safe_serialization=True)
+ model.config.save_pretrained(cache_dir)
+ tokenizer.save_pretrained(cache_dir)
+
+
+def download_peft(
+ model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
+):
+ torch_dtype = torch.float16
+ try:
+ _model = AutoPeftModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ except Exception:
+ _model = AutoPeftModelForSeq2SeqLM.from_pretrained(
+ model_id,
+ revision=revision,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ logger.info("Peft model downloaded.")
diff --git a/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
new file mode 100644
index 000000000..c227d30f5
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
@@ -0,0 +1,24 @@
+from typing import Optional
+
+SUPPORT_CHUNKING: Optional[bool] = None
+MAX_PREFILL_TOKENS: Optional[int] = None
+
+
+def set_support_chunking(support_chunking: bool):
+ global SUPPORT_CHUNKING
+ SUPPORT_CHUNKING = support_chunking
+
+
+def get_support_chunking() -> bool:
+ global SUPPORT_CHUNKING
+ return SUPPORT_CHUNKING
+
+
+def set_max_prefill_tokens(max_prefill_tokens: int):
+ global MAX_PREFILL_TOKENS
+ MAX_PREFILL_TOKENS = max_prefill_tokens
+
+
+def get_max_prefill_tokens() -> int:
+ global MAX_PREFILL_TOKENS
+ return MAX_PREFILL_TOKENS
diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py
new file mode 100644
index 000000000..a8faf4a59
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/quantization.py
@@ -0,0 +1,149 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+from huggingface_hub import hf_hub_download
+from text_generation_server.utils.weights import (
+ WeightsLoader,
+)
+
+
+# TODO: Split this config to have a single config type per quant method
+@dataclass
+class _QuantizerConfig:
+ bits: int
+ checkpoint_format: Optional[str]
+ desc_act: bool
+ groupsize: int
+ quant_method: str
+ sym: bool
+
+
+@dataclass
+class _FP8QuantizerConfig:
+ activation_scale_ub: float
+
+
+# We should probably do this with Pydantic JSON deserialization,
+# but for now we'll stay close to the old _set_gptq_params.
+def _get_quantizer_config(model_id, revision):
+ bits = 4
+ groupsize = -1
+ quant_method = "gptq"
+ checkpoint_format = None
+ sym = False
+ desc_act = False
+
+ filename = "config.json"
+ try:
+ if os.path.exists(os.path.join(model_id, filename)):
+ filename = os.path.join(model_id, filename)
+ else:
+ filename = hf_hub_download(model_id, filename=filename, revision=revision)
+ with open(filename, "r") as f:
+ data = json.load(f)
+
+ # FP8 config
+ if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
+ return _FP8QuantizerConfig(
+ activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
+ )
+
+ if "zero_point" in data["quantization_config"]:
+ sym = not data["quantization_config"]["zero_point"]
+ quant_method = "awq"
+ elif "sym" in data["quantization_config"]:
+ sym = data["quantization_config"]["sym"]
+
+ bits = data["quantization_config"]["bits"]
+ groupsize = data["quantization_config"]["group_size"]
+ # Order is important here, desc_act is missing on some real models
+ quant_method = data["quantization_config"]["quant_method"]
+ checkpoint_format = data["quantization_config"].get("checkpoint_format")
+ desc_act = data["quantization_config"]["desc_act"]
+ except Exception:
+ filename = "quantize_config.json"
+ try:
+ if os.path.exists(os.path.join(model_id, filename)):
+ filename = os.path.join(model_id, filename)
+ else:
+ filename = hf_hub_download(
+ model_id, filename=filename, revision=revision
+ )
+ with open(filename, "r") as f:
+ data = json.load(f)
+ bits = data["bits"]
+ groupsize = data["group_size"]
+
+ if "zero_point" in data:
+ sym = not data["zero_point"]
+ quant_method = "awq"
+ elif "sym" in data:
+ sym = data["sym"]
+
+ desc_act = data["desc_act"]
+ if "version" in data and data["version"] == "GEMM":
+ quant_method = "awq"
+ except Exception:
+ filename = "quant_config.json"
+ try:
+ if os.path.exists(os.path.join(model_id, filename)):
+ filename = os.path.join(model_id, filename)
+ else:
+ filename = hf_hub_download(
+ model_id, filename=filename, revision=revision
+ )
+ with open(filename, "r") as f:
+ data = json.load(f)
+ bits = data["w_bit"]
+ groupsize = data["q_group_size"]
+ desc_act = data["desc_act"]
+ if "version" in data and data["version"] == "GEMM":
+ quant_method = "awq"
+ except Exception:
+ pass
+
+ return _QuantizerConfig(
+ bits=bits,
+ groupsize=groupsize,
+ quant_method=quant_method,
+ checkpoint_format=checkpoint_format,
+ sym=sym,
+ desc_act=desc_act,
+ )
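+
+# Illustrative sketch (assumed field values, not from the upstream source): a
+# GPTQ checkpoint whose config.json contains
+#   "quantization_config": {"quant_method": "gptq", "bits": 4,
+#                           "group_size": 128, "sym": true, "desc_act": false}
+# resolves to _QuantizerConfig(bits=4, groupsize=128, quant_method="gptq",
+# checkpoint_format=None, sym=True, desc_act=False); AWQ checkpoints are
+# recognized through their "zero_point" field instead (sym = not zero_point).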
+
+
+def get_loader(
+ quantize: Optional[str], model_id: str, revision: Optional[str]
+) -> WeightsLoader:
+ quantizer_config = _get_quantizer_config(model_id, revision)
+ if quantize in {"awq", "gptq"}:
+ from text_generation_server.layers.gptq import GPTQWeightsLoader
+
+ # TODO: improve check once we have one config type per quantize value
+ if not isinstance(quantizer_config, _QuantizerConfig):
+ raise ValueError(
+ f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+ )
+
+ return GPTQWeightsLoader(
+ bits=quantizer_config.bits,
+ desc_act=quantizer_config.desc_act,
+ groupsize=quantizer_config.groupsize,
+ quant_method=quantizer_config.quant_method,
+ quantize=quantize,
+ sym=quantizer_config.sym,
+ )
+ elif quantize == "fp8" or quantize is None:
+ from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+ # Since the default for the quantize config is _QuantizerConfig,
+ # we need to add this check to not get an attribute error
+ activation_scale_ub = None
+ if isinstance(quantizer_config, _FP8QuantizerConfig):
+ activation_scale_ub = quantizer_config.activation_scale_ub
+
+ return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+ else:
+ raise ValueError(f"Unknown quantization method: {quantize}")
diff --git a/backends/gaudi/server/text_generation_server/utils/segments.py b/backends/gaudi/server/text_generation_server/utils/segments.py
new file mode 100644
index 000000000..f59611021
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/segments.py
@@ -0,0 +1,66 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/segments.py
+# License: Apache License Version 2.0, January 2004
+
+from typing import List, Tuple, Union
+
+import torch
+
+
+def find_segments(
+ adapter_indices: Union[torch.Tensor, List[int]]
+) -> Tuple[List[int], List[int]]:
+ segments = [0]
+ segment_indices = []
+
+ if isinstance(adapter_indices, torch.Tensor):
+ # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
+ adapter_indices = adapter_indices.cpu().tolist()
+
+ start_index = 0
+ for i in range(1, len(adapter_indices)):
+ if adapter_indices[i] != adapter_indices[i - 1]:
+ segments.append(i)
+ segment_indices.append(adapter_indices[i - 1])
+ start_index = i
+
+ # Handle the last segment
+ if start_index < len(adapter_indices):
+ segments.append(len(adapter_indices))
+ segment_indices.append(adapter_indices[-1])
+
+ return segments, segment_indices
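+
+# Illustrative sketch (assumed indices, not from the upstream source):
+#   find_segments([0, 0, 1, 1, 2]) -> ([0, 2, 4, 5], [0, 1, 2])
+# i.e. rows 0-1 use adapter 0, rows 2-3 use adapter 1 and row 4 uses adapter 2.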
+
+
+class SegmentConcatBuilder:
+ def __init__(self):
+ self.adapter_segment_indices = []
+ self.adapter_segment_tensors = []
+
+ def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
+ # Update adapter segments
+ if self.adapter_segment_tensors:
+ # Because we have already processed at least one batch, remove the 0 start index
+ # from this batch denoting the beginning of the segment, then offset all segment
+ # positions by the value of the last segment in the previous batch to account for
+ # the concatenation.
+ adapter_segments = (
+ adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
+ )
+
+ if (
+ self.adapter_segment_indices
+ and self.adapter_segment_indices[-1] == segment_indices[0]
+ ):
+ # If the last segment in the previous batch is the same as the first segment in this batch,
+ # then we merge them together into a single segment. In effect, this means removing it from
+ # the segment indices of this batch, and extending the segment span by removing the segment
+ # end index from the previous batch.
+ segment_indices = segment_indices[1:]
+ self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]
+
+ self.adapter_segment_indices.extend(segment_indices)
+ self.adapter_segment_tensors.append(adapter_segments)
+
+ def build(self) -> Tuple[torch.Tensor, List[int]]:
+ return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
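+
+# Illustrative sketch (assumed values, not from the upstream source): if the
+# first batch contributes segments [0, 2, 4] for adapters [0, 1] and the second
+# batch contributes segments [0, 3] for adapter [1], then
+#   b = SegmentConcatBuilder()
+#   b.concat(torch.tensor([0, 2, 4]), [0, 1])
+#   b.concat(torch.tensor([0, 3]), [1])
+#   b.build() -> (tensor([0, 2, 7]), [0, 1])
+# because the trailing adapter-1 segment of the first batch is merged with the
+# leading adapter-1 segment of the second batch.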
diff --git a/backends/gaudi/server/text_generation_server/utils/sgmv.py b/backends/gaudi/server/text_generation_server/utils/sgmv.py
new file mode 100644
index 000000000..2d0a73a54
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/sgmv.py
@@ -0,0 +1,252 @@
+# Origin: https://github.com/predibase/lorax
+# Path: lorax/server/lorax_server/utils/sgmv.py
+# License: Apache License Version 2.0, January 2004
+
+import os
+import warnings
+from functools import lru_cache
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+
+try:
+ import punica_kernels as _kernels
+
+ HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
+except ImportError:
+ warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
+ _kernels = None
+ HAS_SGMV = False
+
+
+MIN_SGMV_RANK = 8
+MIN_RANK_CUSTOM = 16
+MAX_RANK_CUSTOM = 128
+SGMV_BLOCK_SIZE = 16
+BGMV_MAX_RANK = 64
+
+
+def has_sgmv() -> bool:
+ return HAS_SGMV
+
+
+def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+ """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
+ if not has_sgmv():
+ return t
+
+ # tensor parallelism will result in effective rank being divided by world_size,
+ # so we need to scale the min rank to offset that effect
+ min_rank = MIN_SGMV_RANK * world_size
+
+ # if we're at or below the min rank, pad up to the min rank
+ # otherwise, pad to the nearest multiple of the block size
+ current_rank = t.size(dim)
+ target_rank = (
+ min_rank
+ if current_rank <= min_rank
+ else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
+ )
+ if current_rank == target_rank:
+ return t
+
+ pad_size = target_rank - current_rank
+
+ # see the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+ pad = [0, 0] * t.dim()
+ pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
+ pad = tuple(pad)
+
+ return F.pad(t, pad, mode="constant", value=0.0)
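+
+# Illustrative sketch (assumed shapes, not from the upstream source): with the
+# SGMV kernels available and world_size=1, a LoRA tensor of rank 4 along `dim`
+# is padded up to the minimum rank of 8, while a rank-10 tensor is padded up to
+# 16, the next multiple of SGMV_BLOCK_SIZE; without the kernels pad_rank
+# returns the tensor unchanged.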
+
+
+def use_cutlass_shrink(lora_rank: int) -> bool:
+ return lora_rank < MIN_RANK_CUSTOM
+
+
+def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
+ if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
+ return t.transpose(0, 1)
+ return t
+
+
+# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
+def add_lora_sgmv_cutlass(
+ y: torch.Tensor,
+ x: torch.Tensor,
+ wa_ptr: torch.Tensor,
+ wb_ptr: torch.Tensor,
+ s_start: torch.Tensor,
+ s_end: torch.Tensor,
+ layer_idx: int,
+ lora_rank: int,
+):
+ """
+ Semantics:
+ y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
+
+ Args:
+ y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+ x: Shape: `[B, H1]`. Input vectors.
+ wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
+ Weight matrix shape: `[num_layers, R, H1]`.
+ wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
+ Weight matrix shape: `[num_layers, R, H2]`.
+ s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
+ s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
+ layer_idx: Layer index of the weight matrices.
+ """
+ if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
+ # Custom SGMV shrink only supports rank 16, 32, 64, 128
+ _add_lora_sgmv_cutlass_legacy(
+ y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
+ )
+ return
+
+ tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
+ tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+ tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+ v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+ _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
+ _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
+
+
+def _add_lora_sgmv_cutlass_legacy(
+ y: torch.Tensor,
+ x: torch.Tensor,
+ wa_ptr: torch.Tensor,
+ wb_ptr: torch.Tensor,
+ s_start: torch.IntTensor,
+ s_end: torch.IntTensor,
+ layer_idx: int,
+ lora_rank: int,
+):
+ tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+ tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+ v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+ _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+ _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+
+@lru_cache(maxsize=1)
+def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+ return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+
+@lru_cache(maxsize=32)
+def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+ tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
+ return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+
+def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
+ return torch.empty((size,), dtype=torch.uint8, device=device)
+
+
+def get_tmp_expand_size(size: int) -> int:
+ return _kernels.sgmv_cutlass_tmp_size(size)
+
+
+def get_tmp_tensors(
+ nsegments: int, lora_rank: int, device: torch.device
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
+ has_sgmv_available = has_sgmv()
+
+ if use_cutlass:
+ tmp = get_tmp_tensor_for_size(nsegments, device)
+ return tmp, tmp
+ elif has_sgmv_available:
+ return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
+ else:
+ tmp = get_tmp_tensor_for_size(nsegments, device)
+ return tmp, tmp
+
+
+def lora_a_sgmv_cutlass(
+ x: torch.Tensor,
+ tmp: torch.Tensor,
+ wa_ptr: torch.Tensor,
+ s_start: torch.IntTensor,
+ s_end: torch.IntTensor,
+ layer_idx: int,
+ lora_rank: int,
+) -> torch.Tensor:
+ v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+ if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
+ _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+ else:
+ _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+ return v
+
+
+def lora_b_sgmv_cutlass(
+ y: torch.Tensor,
+ v: torch.Tensor,
+ tmp: torch.Tensor,
+ wb_ptr: torch.Tensor,
+ s_start: torch.IntTensor,
+ s_end: torch.IntTensor,
+ layer_idx: int,
+):
+ _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+
+"""
+Semantics:
+ y[i] += (
+ x[i].unsqueeze(0)
+ @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+ @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+ * scale
+ ).squeeze(0)
+
+Args:
+ y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+ v: Shape: `[B, R]`. Temporary vector.
+ x: Shape: `[B, H1]`. Input vectors.
+ wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
+ wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
+ indicies: Shape: `[B]`. Indices of the LoRA weights.
+ layer_idx: Layer index of LoRA weights.
+ scale: Scaling factor.
+"""
+
+
+def add_lora_a_bgmv(
+ v: torch.Tensor,
+ x: torch.Tensor,
+ wa_T_all: torch.Tensor,
+ indicies: torch.LongTensor,
+ layer_idx: int,
+):
+ _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
+
+
+def add_lora_b_bgmv(
+ y: torch.Tensor,
+ v: torch.Tensor,
+ wb_T_all: torch.Tensor,
+ indicies: torch.LongTensor,
+ layer_idx: int,
+):
+ _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
+
+
+def segmented_matmul(
+ y: torch.Tensor,
+ x: torch.Tensor,
+ w: List[torch.Tensor],
+ b: List[torch.Tensor],
+ s_start: torch.IntTensor,
+ s_end: torch.IntTensor,
+):
+ for i in range(len(w)):
+ if s_end[i] - s_start[i] <= 0:
+ continue
+
+ xi = x[s_start[i] : s_end[i]]
+ wi = w[i]
+ bi = b[i]
+ y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
diff --git a/backends/gaudi/server/text_generation_server/utils/speculate.py b/backends/gaudi/server/text_generation_server/utils/speculate.py
new file mode 100644
index 000000000..a1b37a344
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/speculate.py
@@ -0,0 +1,11 @@
+SPECULATE = None
+
+
+def get_speculate() -> int:
+ global SPECULATE
+ return SPECULATE
+
+
+def set_speculate(speculate: int):
+ global SPECULATE
+ SPECULATE = speculate
diff --git a/backends/gaudi/server/text_generation_server/utils/tokens.py b/backends/gaudi/server/text_generation_server/utils/tokens.py
new file mode 100644
index 000000000..9c44ba15c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/tokens.py
@@ -0,0 +1,762 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import re
+from typing import List, Optional, Tuple, Set, Union
+
+import torch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
+from text_generation_server.utils.logits_process import (
+ FrequencyPenaltyLogitsProcessor,
+ GrammarLogitProcessor,
+ HeterogeneousProcessorWrapper,
+ HeterogeneousRepetitionPenaltyLogitsProcessor,
+ HeterogeneousFrequencyPenaltyLogitsProcessor,
+ HeterogeneousTemperatureLogitsWarper,
+ HeterogeneousTopKLogitsWarper,
+ HeterogeneousTopPLogitsWarper,
+ HeterogeneousTypicalLogitsWarper,
+ HeterogeneousGrammarLogitProcessor,
+ static_warper,
+)
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+import os
+
+
+class NextTokenChooser:
+ def __init__(
+ self,
+ watermark: bool = False,
+ temperature: float = 1.0,
+ repetition_penalty: float = 1.0,
+ frequency_penalty: float = 0.0,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ typical_p: Optional[float] = None,
+ do_sample: bool = False,
+ seed: int = 0,
+ device: str = "cpu",
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ grammar: str = "",
+ grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
+ fsm_grammar_state: int = 0,
+ ):
+ self.watermark_processor = (
+ WatermarkLogitsProcessor(device=device) if watermark else None
+ )
+ self.repetition_processor = (
+ RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+ if repetition_penalty and repetition_penalty != 1.0
+ else None
+ )
+ self.frequency_processor = (
+ FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
+ if frequency_penalty and frequency_penalty != 0.0
+ else None
+ )
+ self.grammar_processor = (
+ GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+ if grammar != ""
+ else None
+ )
+ self.tokenizer = tokenizer
+
+ has_warpers = (
+ (temperature is not None and temperature != 1.0)
+ or (top_k is not None and top_k != 0)
+ or (top_p is not None and top_p < 1.0)
+ or (typical_p is not None and typical_p < 1.0)
+ )
+ if has_warpers:
+ self.static_warper = static_warper(
+ temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+ )
+ else:
+ self.static_warper = None
+
+ sampling = do_sample or has_warpers
+
+ self.choice = Sampling(seed, device) if sampling else Greedy()
+ self.fsm_grammar_state = fsm_grammar_state
+ self.grammar = grammar
+
+ def __call__(self, input_ids, scores):
+ if self.watermark_processor is not None:
+ scores = self.watermark_processor(input_ids, scores)
+ if self.repetition_processor is not None:
+ scores = self.repetition_processor(input_ids, scores)
+ if self.frequency_processor is not None:
+ scores = self.frequency_processor(input_ids, scores)
+ if self.grammar_processor is not None:
+ scores = self.grammar_processor(scores, self.fsm_grammar_state)
+
+ if self.static_warper is None:
+ next_logprob = torch.log_softmax(scores, -1)
+ else:
+ scores, next_logprob = self.static_warper(scores)
+
+ next_id = self.choice(scores[-1]).view(1, 1)
+
+ return next_id, next_logprob
+
+ def advance_grammar(self, next_id: int):
+ if self.grammar_processor is not None:
+ self.fsm_grammar_state = self.grammar_processor.advance(
+ next_id, self.fsm_grammar_state
+ )
+ return self
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.NextTokenChooserParameters,
+ device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> "NextTokenChooser":
+ return NextTokenChooser(
+ watermark=pb.watermark,
+ temperature=pb.temperature,
+ repetition_penalty=pb.repetition_penalty,
+ frequency_penalty=pb.frequency_penalty,
+ top_k=pb.top_k,
+ top_p=pb.top_p,
+ typical_p=pb.typical_p,
+ do_sample=pb.do_sample,
+ seed=pb.seed,
+ device=device,
+ tokenizer=tokenizer,
+ grammar=pb.grammar,
+ grammar_type=pb.grammar_type,
+ )
+
+
+class StopSequenceCriteria:
+ def __init__(self, stop_sequence: str):
+ stop_sequence = re.escape(stop_sequence)
+ self.regex = re.compile(f"{stop_sequence}$")
+
+ def __call__(self, output: str) -> bool:
+ if self.regex.findall(output):
+ return True
+ return False
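+
+# Illustrative sketch (assumed stop string, not from the upstream source):
+# StopSequenceCriteria("###")("some generated text ###") returns True because
+# the escaped stop sequence is anchored to the end of the accumulated output,
+# while StopSequenceCriteria("###")("### early") returns False.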
+
+
+class StoppingCriteria:
+ def __init__(
+ self,
+ eos_token_ids: Optional[Union[Set[int], int]],
+ stop_sequence_criterias: List[StopSequenceCriteria],
+ max_new_tokens: int = 20,
+ ignore_eos_token: bool = False,
+ ):
+ if eos_token_ids is None:
+ eos_token_ids = set()
+ elif isinstance(eos_token_ids, int):
+ eos_token_ids = set([eos_token_ids])
+ elif isinstance(eos_token_ids, set):
+ eos_token_ids = eos_token_ids
+ else:
+ raise RuntimeError(
+ f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+ )
+ self.eos_token_ids = eos_token_ids
+ self.stop_sequence_criterias = stop_sequence_criterias
+ self.max_new_tokens = max_new_tokens
+ self.current_tokens = 0
+ self.current_output = ""
+
+ if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true":
+ self.ignore_eos_token = True
+ else:
+ self.ignore_eos_token = ignore_eos_token
+
+ def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
+ self.current_tokens += 1
+ if self.current_tokens >= self.max_new_tokens:
+ return True, FinishReason.FINISH_REASON_LENGTH
+
+ if isinstance(last_token, torch.Tensor):
+ last_token = last_token.item()
+
+ if not self.ignore_eos_token and last_token in self.eos_token_ids:
+ return True, FinishReason.FINISH_REASON_EOS_TOKEN
+
+ if self.stop_sequence_criterias:
+ self.current_output += last_output
+ # There is no need to keep an output that is too long
+ if len(self.current_output) > 300:
+ # Slice to -200 to avoid doing it all the time
+ self.current_output = self.current_output[-200:]
+ for stop_sequence_criteria in self.stop_sequence_criterias:
+ if stop_sequence_criteria(self.current_output):
+ return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
+
+ return False, None
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: generate_pb2.StoppingCriteriaParameters,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> "StoppingCriteria":
+ stop_sequence_criterias = [
+ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+ ]
+ # TODO Hack because eos_token_id cannot be what we want.
+ eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
+ return StoppingCriteria(
+ eos_token_id,
+ stop_sequence_criterias,
+ pb.max_new_tokens,
+ pb.ignore_eos_token,
+ )
+
+
+def create_n_gram_speculation(
+ input_ids: torch.Tensor,
+ next_ids: torch.Tensor,
+ accepted_ids: torch.Tensor,
+ speculate: int,
+ verbose: bool,
+):
+ # Very simple approach: find the first match in the sequence.
+ # This is much less refined than a real n-gram search but seems to work
+ # reasonably well in grounded mode, and it is far faster with a much lower
+ # worst-case complexity since everything happens on device.
+ B = accepted_ids.shape[0]
+ device = input_ids.device
+ seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
+ indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
+ all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
+ speculate, device=device
+ )
+ all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
+
+ speculative_ids = input_ids.gather(dim=-1, index=all_indices)
+ return speculative_ids
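+
+# Illustrative sketch (assumed token ids, not from the upstream source): with
+# input_ids [[5, 8, 3, 8, 9]], a single accepted next id of 8 and speculate=2,
+# the first earlier occurrence of 8 is at position 1, so the two tokens that
+# follow it, [3, 8], are proposed as the speculative continuation.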
+
+
+class HeterogeneousNextTokenChooser:
+ def __init__(
+ self,
+ dtype: torch.dtype,
+ device: torch.device,
+ watermark: List[bool],
+ temperature: List[float],
+ repetition_penalty: List[float],
+ frequency_penalty: List[float],
+ top_k: List[int],
+ top_p: List[float],
+ typical_p: List[float],
+ do_sample: List[bool],
+ seeds: List[int],
+ tokenizer: PreTrainedTokenizerBase,
+ grammars: List[str],
+ grammar_types: List[int],
+ fsm_grammar_states: List[int],
+ quantization_enabled: bool,
+ ):
+ warpers = []
+
+ # TODO: enable watermark with FP8 quantization
+ self.watermark_processor = (
+ HeterogeneousProcessorWrapper(
+ {
+ i: WatermarkLogitsProcessor(device=device)
+ for i, do_watermark in enumerate(watermark)
+ if do_watermark
+ }
+ )
+ if any(watermark) and not quantization_enabled
+ else None
+ )
+
+ self.repetition_processor = (
+ HeterogeneousRepetitionPenaltyLogitsProcessor(
+ repetition_penalty, dtype, device
+ )
+ if any([x != 1.0 for x in repetition_penalty])
+ else None
+ )
+
+ self.frequency_processor = (
+ HeterogeneousFrequencyPenaltyLogitsProcessor(
+ frequency_penalty, dtype, device
+ )
+ if any([x != 0.0 for x in frequency_penalty])
+ else None
+ )
+
+ self.grammar_processor = (
+ HeterogeneousGrammarLogitProcessor(
+ tokenizer, device, grammars, grammar_types
+ )
+ if any([grammar != "" for grammar in grammars])
+ else None
+ )
+
+ if any(x != 1.0 for x in temperature):
+ do_sample = [
+ sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+ ]
+ warpers.append(
+ HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
+ )
+
+ if any(x != 0 for x in top_k):
+ do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
+ warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
+
+ if any(x < 1.0 for x in top_p):
+ do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
+ warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
+
+ if any(x < 1.0 for x in typical_p):
+ do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
+ warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
+
+ self.warpers = warpers
+
+ if any(do_sample):
+ self.choice = HeterogeneousSampling(do_sample, seeds, device)
+ else:
+ self.choice = Greedy()
+
+ self.seeds = seeds
+ self.do_sample = do_sample
+ self.dtype = dtype
+ self.device = device
+ self.tokenizer = tokenizer
+ self.fsm_grammar_states = fsm_grammar_states
+ self.grammars = grammars
+ self.grammar_types = grammar_types
+
+ def __call__(
+ self,
+ input_ids: torch.Tensor,
+ scores: torch.Tensor,
+ speculate: int,
+ speculated_ids: Optional[torch.Tensor] = None,
+ speculative_scores: Optional[torch.Tensor] = None,
+ verbose=False,
+ ):
+ if speculated_ids is not None:
+ B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+ S = speculated_ids.shape[1] + 1
+ scores = scores.view(B, S, -1)
+ else:
+ B = scores.shape[0]
+ S = 1
+ scores = scores.view(B, S, -1)
+
+ next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+
+ for j in range(S):
+ _scores = scores[:, j]
+ if self.watermark_processor is not None:
+ _scores = self.watermark_processor(input_ids, _scores)
+ if self.repetition_processor is not None:
+ _scores = self.repetition_processor(input_ids, _scores)
+ if self.frequency_processor is not None:
+ _scores = self.frequency_processor(input_ids, _scores)
+ if self.grammar_processor is not None:
+ _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+ for warper in self.warpers:
+ _scores = warper(input_ids, _scores)
+ _next_ids = self.choice(_scores)
+ scores[:, j] = _scores
+ next_ids[:, j] = _next_ids
+ next_ids = next_ids.view(B * S)
+ allscores = scores.view(B * S, -1)
+ alllogprobs = torch.log_softmax(allscores, -1)
+
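+ # Compare the previously speculated ids against the freshly sampled ones:
+ # the first sampled token is always kept, then speculation is accepted
+ # greedily until the first mismatch. Illustrative sketch (assumed ids, not
+ # from the upstream source): speculated [7, 9] with sampled [7, 3, 5] keeps
+ # [7, 3] and reports 2 accepted ids for that sequence.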
+ if speculated_ids is not None:
+ accepted_ids = []
+ B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+ S = speculated_ids.shape[1] + 1
+ indices = []
+ for i in range(B):
+ _next_ids = next_ids[i * S : (i + 1) * S]
+ _speculated_ids = speculated_ids[i]
+ validate_speculative = _next_ids[:-1] == _speculated_ids
+ index = i * S
+ accepted = 1
+ # First is always valid
+ indices.append(index)
+ for valid in validate_speculative.tolist():
+ if valid:
+ index += 1
+ accepted += 1
+ indices.append(index)
+ else:
+ break
+ accepted_ids.append(accepted)
+
+ accepted_ids = torch.tensor(
+ accepted_ids, device=input_ids.device, dtype=input_ids.dtype
+ )
+ next_ids = next_ids[indices]
+ logprobs = alllogprobs[indices]
+ indices = torch.arange(B, device=input_ids.device) * S
+ if speculative_scores is not None:
+ speculative_scores = speculative_scores[indices + accepted_ids - 1]
+ else:
+ accepted_ids = torch.ones_like(next_ids)
+ logprobs = alllogprobs
+
+ next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
+
+ if speculate > 0:
+ if speculative_scores is not None:
+ # Medusa provided some scores
+ speculative_ids = Greedy()(speculative_scores)
+ else:
+ # n-gram
+ speculative_ids = create_n_gram_speculation(
+ input_ids, next_ids, accepted_ids, speculate, verbose
+ )
+ else:
+ speculative_ids = None
+
+ return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
+
+ def advance_grammar(self, next_ids: List[int]):
+ if self.grammar_processor is not None:
+ other_new_states = self.grammar_processor.advance_batch(
+ next_ids, self.fsm_grammar_states
+ )
+ self.fsm_grammar_states = other_new_states
+ return self
+
+ def advance_grammar_single(self, grammar_state_index: int, next_id: int):
+ if self.grammar_processor is not None:
+ self.fsm_grammar_states[grammar_state_index] = (
+ self.grammar_processor.advance_at_index(
+ next_id,
+ self.fsm_grammar_states[grammar_state_index],
+ grammar_state_index,
+ )
+ )
+ return self
+
+ def advance_grammar_single_with_past_state(
+ self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
+ ):
+ if self.grammar_processor is not None:
+ next_id = next_id.item()
+ self.fsm_grammar_states[grammar_state_index] = (
+ self.grammar_processor.advance_at_index(
+ next_id,
+ past_state,
+ grammar_state_index,
+ )
+ )
+ return self
+
+ def filter(self, indices):
+ if self.watermark_processor is not None:
+ self.watermark_processor = self.watermark_processor.filter(indices)
+
+ if self.repetition_processor is not None:
+ self.repetition_processor = self.repetition_processor.filter(indices)
+
+ if self.frequency_processor is not None:
+ self.frequency_processor = self.frequency_processor.filter(indices)
+
+ if self.grammar_processor is not None:
+ self.grammar_processor = self.grammar_processor.filter(indices)
+
+ filtered_warpers = []
+ for warper in self.warpers:
+ filtered_warper = warper.filter(indices)
+ if filtered_warper is not None:
+ filtered_warpers.append(filtered_warper)
+ self.warpers = filtered_warpers
+
+ self.seeds = [self.seeds[i] for i in indices]
+ self.do_sample = [self.do_sample[i] for i in indices]
+
+ new_grammars = []
+ new_fsm_grammar_states = []
+ new_grammar_types = []
+ for i in indices:
+ new_grammars.append(self.grammars[i])
+ new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+ new_grammar_types.append(self.grammar_types[i])
+
+ self.grammars = new_grammars
+ self.fsm_grammar_states = new_fsm_grammar_states
+ self.grammar_types = new_grammar_types
+
+ if any(self.do_sample):
+ self.choice.filter(indices)
+ else:
+ self.choice = Greedy()
+
+ return self
+
+ @classmethod
+ def from_pb(
+ cls,
+ pb: List[generate_pb2.NextTokenChooserParameters],
+ dtype: torch.dtype,
+ device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
+ fsm_grammar_states: Optional[List[int]] = None,
+ quantization_enabled: bool = False,
+ ) -> "HeterogeneousNextTokenChooser":
+ return HeterogeneousNextTokenChooser(
+ watermark=[pb_.watermark for pb_ in pb],
+ temperature=[pb_.temperature for pb_ in pb],
+ repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
+ frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
+ top_k=[pb_.top_k for pb_ in pb],
+ top_p=[pb_.top_p for pb_ in pb],
+ typical_p=[pb_.typical_p for pb_ in pb],
+ do_sample=[pb_.do_sample for pb_ in pb],
+ seeds=[pb_.seed for pb_ in pb],
+ device=device,
+ dtype=dtype,
+ tokenizer=tokenizer,
+ grammars=[pb_.grammar for pb_ in pb],
+ grammar_types=[pb_.grammar_type for pb_ in pb],
+ fsm_grammar_states=(
+ fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
+ ),
+ quantization_enabled=quantization_enabled,
+ )
+
+
+def pad_next_token_chooser_parameters(
+ parameters: List[generate_pb2.NextTokenChooserParameters],
+ expected_size: int,
+) -> List[generate_pb2.NextTokenChooserParameters]:
+ # disable all logits processors to minimize padding overhead
+ empty_parameters = generate_pb2.NextTokenChooserParameters(
+ temperature=1.0,
+ top_k=0,
+ top_p=1.0,
+ typical_p=1.0,
+ do_sample=False,
+ seed=0,
+ repetition_penalty=1.0,
+ frequency_penalty=0.0,
+ watermark=False,
+ grammar="",
+ grammar_type=0,
+ )
+ parameters.extend([empty_parameters] * (expected_size - len(parameters)))
+ return parameters
+
+
+class Sampling:
+ def __init__(self, seed: int, device: str = "cpu"):
+ self.generator = torch.Generator("cpu")
+ self.generator.manual_seed(seed)
+ self.seed = seed
+
+ def __call__(self, logits):
+ probs = torch.nn.functional.softmax(logits, -1)
+ # Avoid GPU<->CPU sync done by torch multinomial
+ # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
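+ # Dividing by i.i.d. Exponential(1) noise and taking the argmax (an
+ # exponential-race / Gumbel-max style trick) yields an exact sample from the
+ # categorical distribution defined by `probs`.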
+ q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+ return probs.div_(q).argmax()
+
+
+class Greedy:
+ def __call__(self, logits):
+ return logits.argmax(dim=-1)
+
+
+class HeterogeneousSampling:
+ r"""
+ Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
+ """
+
+ def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
+ self.seeds = seeds
+
+ self.greedy_indices = []
+ self.sampling_mapping = {}
+ for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
+ if sample:
+ self.sampling_mapping[i] = Sampling(seed, device)
+ else:
+ self.greedy_indices.append(i)
+
+ self.greedy = Greedy()
+
+ def __call__(self, logits):
+ out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
+ if self.greedy_indices:
+ # Computing for all indices is faster than slicing
+ torch.argmax(logits, -1, out=out)
+
+ for i, sampling in self.sampling_mapping.items():
+ out[i] = sampling(logits[i])
+ return out
+
+ def filter(self, indices):
+ new_greedy_indices = []
+ new_sampling_mapping = {}
+ for i, idx in enumerate(indices):
+ if idx in self.sampling_mapping:
+ new_sampling_mapping[i] = self.sampling_mapping[idx]
+ else:
+ new_greedy_indices.append(i)
+
+ self.greedy_indices = new_greedy_indices
+ self.sampling_mapping = new_sampling_mapping
+ return self
+
+
+def batch_top_tokens(
+ top_n_tokens: List[int],
+ top_n_tokens_tensor: torch.Tensor,
+ logprobs: torch.Tensor,
+ accepted_ids: torch.Tensor,
+) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
+ """Find the top n most likely tokens for a batch of generations.
+
+ When multiple tokens have equal probabilities and they don't all fit, the
+ remaining tokens are also returned.
+ """
+ max_top_n = max(top_n_tokens)
+ # Early exit when top_n_tokens is not used
+ if max_top_n == 0:
+ return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
+
+ batch_size = accepted_ids.shape[0]
+ speculate_size = logprobs.shape[0] // batch_size
+ top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
+ # Ensure top_n doesn't exceed vocab size
+ top_n_tokens = [
+ min(tok, logprobs.size(-1))
+ for tok in top_n_tokens
+ for _ in range(speculate_size)
+ ]
+
+ # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+ # Sorted topk is faster than torch.sort() since we only need a small subset
+ sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
+
+ nth_highest = torch.gather(
+ sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+ )
+ nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+ # Find the new "fuzzy" top n values
+ top_n_indices = (logprobs >= nth_highest).nonzero()
+ _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+ k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
+ # Take a new topk for these new max n values
+ top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
+
+ top_n_ishes = top_n_ishes.tolist()
+ top_indices = top_k.indices.tolist()
+ top_values = top_k.values.tolist()
+
+ batch_top_token_ids = []
+ batch_top_token_logprobs = []
+ accepted_ids_list = accepted_ids.tolist()
+ for i, n_accepted_ids in enumerate(accepted_ids_list):
+ start = speculate_size * i
+ stop = speculate_size * (i + 1)
+ _top_indices = top_indices[start:stop]
+ _top_values = top_values[start:stop]
+ _top_n_ishes = top_n_ishes[start:stop]
+ _top_n_tokens = top_n_tokens[start:stop]
+
+ _top_indices = _top_indices[:n_accepted_ids]
+ _top_values = _top_values[:n_accepted_ids]
+ _top_n_ishes = _top_n_ishes[:n_accepted_ids]
+ _top_n_tokens = _top_n_tokens[:n_accepted_ids]
+
+ row_top_token_ids = []
+ row_top_token_logprobs = []
+
+ for idxs, vals, n, req_n in zip(
+ _top_indices, _top_values, _top_n_ishes, _top_n_tokens
+ ):
+ indices = idxs[:n] if req_n > 0 else []
+ values = vals[:n] if req_n > 0 else []
+
+ row_top_token_ids.append(indices)
+ row_top_token_logprobs.append(values)
+
+ batch_top_token_ids.append(row_top_token_ids)
+ batch_top_token_logprobs.append(row_top_token_logprobs)
+
+ return batch_top_token_ids, batch_top_token_logprobs
+
+
+def make_tokenizer_optional(tokenizer):
+ class _(type(tokenizer)):
+ def __call__(
+ self,
+ text,
+ return_tensors,
+ padding,
+ return_token_type_ids,
+ truncation,
+ max_length,
+ ):
+ assert (
+ return_tensors == "pt"
+ ), "inccorrect input arguments when calling TransparentTokenizer"
+ assert (
+ padding == "max_length" or padding == "longest"
+ ), "inccorrect input arguments when calling TransparentTokenizer"
+ assert (
+ not return_token_type_ids
+ ), "inccorrect input arguments when calling TransparentTokenizer"
+ assert (
+ truncation
+ ), "inccorrect input arguments when calling TransparentTokenizer"
+
+ def str_token_to_int(i):
+ if i == "?":
+ return tokenizer.pad_token_id
+ else:
+ return int(i)
+
+ all_tokens = [
+ [str_token_to_int(i.strip()) for i in inner_text.split(",")]
+ for inner_text in text
+ ]
+ if padding == "longest":
+ max_length = max(len(tokens) for tokens in all_tokens)
+ return {
+ "input_ids": torch.tensor(
+ [
+ [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
+ for tokens in all_tokens
+ ]
+ ),
+ "attention_mask": torch.tensor(
+ [
+ [0] * (max_length - len(tokens)) + [1] * len(tokens)
+ for tokens in all_tokens
+ ]
+ ),
+ }
+
+ def decode(
+ self,
+ token_ids,
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = None,
+ **kwargs,
+ ) -> str:
+ # I don't think this method is used anywhere and should be removed when doing refactoring
+ return ",".join(str(i) for i in to_py_obj(token_ids)) # noqa: F821
+
+ if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
+ tokenizer.__class__ = _
+ tokenizer.is_transparent = True
+
+
+def is_tokenizer_transparent(tokenizer):
+ return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
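+
+# Illustrative sketch (assumed ids, not from the upstream source): when
+# SKIP_TOKENIZER_IN_TGI=true the request text is expected to already contain
+# comma-separated token ids, e.g. tokenizer(["15, 27, 9"], ...) returns those
+# ids directly (left-padded with pad_token_id up to max_length), and a "?"
+# entry is mapped to pad_token_id.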
diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py
new file mode 100644
index 000000000..f54b6ae8f
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/version.py
@@ -0,0 +1,12 @@
+from optimum.habana.utils import get_driver_version
+from packaging.version import Version
+
+MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
+
+
+def is_driver_compatible():
+ driver_version = get_driver_version()
+ if driver_version is not None:
+ if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:
+ return False
+ return True
diff --git a/backends/gaudi/server/text_generation_server/utils/watermark.py b/backends/gaudi/server/text_generation_server/utils/watermark.py
new file mode 100644
index 000000000..5092b076c
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/watermark.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2023 Authors of "A Watermark for Large Language Models"
+# available at https://arxiv.org/abs/2301.10226
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+from transformers import LogitsProcessor
+from typing import List, Union
+
+GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
+DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
+
+
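+# Following "A Watermark for Large Language Models" (Kirchenbauer et al., 2023):
+# the previous token seeds an RNG, a pseudo-random "greenlist" of
+# gamma * vocab_size token ids is drawn from it, and the greenlist logits are
+# biased upward by delta, so a detector that knows the hash key can later test
+# for an excess of greenlist tokens.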
+class WatermarkLogitsProcessor(LogitsProcessor):
+ def __init__(
+ self,
+ gamma: float = GAMMA,
+ delta: float = DELTA,
+ hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
+ device: str = "cpu",
+ ):
+ # watermarking parameters
+ self.gamma = gamma
+ self.delta = delta
+ self.rng = torch.Generator(device="cpu")
+ self.hash_key = hash_key
+
+ def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
+ if isinstance(input_ids, list):
+ assert (
+ len(input_ids) >= 1
+ ), "requires at least a 1 token prefix sequence to seed rng"
+ prev_token = input_ids[-1]
+ else:
+ assert len(input_ids) == 1
+ input_ids = input_ids[0]
+ assert (
+ input_ids.shape[-1] >= 1
+ ), "requires at least a 1 token prefix sequence to seed rng"
+ prev_token = input_ids[-1].item()
+ self.rng.manual_seed(self.hash_key * prev_token)
+
+ def _get_greenlist_ids(
+ self,
+ input_ids: Union[List[int], torch.LongTensor],
+ max_value: int,
+ device: torch.device,
+ ) -> List[int]:
+ # seed the rng using the previous tokens/prefix
+ self._seed_rng(input_ids)
+
+ greenlist_size = int(max_value * self.gamma)
+ vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
+ greenlist_ids = vocab_permutation[:greenlist_size]
+ return greenlist_ids
+
+ @staticmethod
+ def _calc_greenlist_mask(
+ scores: torch.FloatTensor, greenlist_token_ids
+ ) -> torch.BoolTensor:
+ green_tokens_mask = torch.zeros_like(scores)
+ green_tokens_mask[-1, greenlist_token_ids] = 1
+ final_mask = green_tokens_mask.bool()
+ return final_mask
+
+ @staticmethod
+ def _bias_greenlist_logits(
+ scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+ ) -> torch.Tensor:
+ scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+ return scores
+
+ def __call__(
+ self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ greenlist_ids = self._get_greenlist_ids(
+ input_ids, scores.shape[-1], scores.device
+ )
+ green_tokens_mask = self._calc_greenlist_mask(
+ scores=scores, greenlist_token_ids=greenlist_ids
+ )
+
+ scores = self._bias_greenlist_logits(
+ scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
+ )
+ return scores
diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py
new file mode 100644
index 000000000..acd598d7a
--- /dev/null
+++ b/backends/gaudi/server/text_generation_server/utils/weights.py
@@ -0,0 +1,437 @@
+import torch
+
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union, Type
+from safetensors import safe_open
+from dataclasses import dataclass
+
+
+class WeightsLoader(ABC):
+ """
+ Instances of this type implement higher-level weight loading.
+
+ At a low level, every weight is stored in the Safetensors format.
+ The interpretation of the weights may differ, however; for instance,
+ they could be packed or quantized weights. Loaders are responsible for
+ interpreting the raw tensors, sharding tensors in a manner compatible
+ with the format, etc.
+ """
+
+ @abstractmethod
+ def get_weights(self, weights: "Weights", prefix: str):
+ """
+ Get weights at the given prefix, without applying tensor parallelism.
+ """
+ ...
+
+ @abstractmethod
+ def get_weights_col_packed(
+ self,
+ weights: "Weights",
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ """
+ Get the packed weights at the given prefix with column-splitting for
+ tensor parallelism. This method should be used when multiple different
+ weights are packed into a tensor, for instance, query/key/value
+ weights or a gate/up projection.
+
+ The `block_sizes` argument determines the proportions of the packed tensors.
+ The columns are split in equally sized blocks when `block_sizes` is an
+ `int`, or in blocks proportional to the given sizes. For instance,
+ `[2, 1, 1]` will divide an input with dimensionality `1024` into
+ `[512, 256, 256]`.
+ """
+ ...
+
+ def get_weights_col(self, weights: "Weights", prefix: str):
+ """
+ Get weights at the given prefix and apply column-splitting for tensor
+ parallelism.
+ """
+ return weights.get_multi_weights_col([prefix], 0)
+
+ @abstractmethod
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ """
+ Get the weights at the given prefixes, column-split them for tensor
+ parallelism, and then concatenate the weights along the given dimension.
+ """
+ ...
+
+ @abstractmethod
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ """
+ Get the weights at the given prefix and apply row-splitting for tensor
+ parallelism.
+ """
+ ...
+
+
+class Weight(ABC):
+ """Instances of this type implement unquantized/quantized/to-be
+ quantized weights."""
+
+ @abstractmethod
+ def get_linear(self, bias: torch.Tensor):
+ """Create a linear layer from this weight."""
+ ...
+
+
+@dataclass
+class UnquantizedWeight(Weight):
+ weight: torch.Tensor
+
+ def get_linear(self, bias: torch.Tensor):
+ from text_generation_server.layers.linear import FastLinear
+
+ return FastLinear(self.weight, bias)
+
+
+class DefaultWeightsLoader(WeightsLoader):
+ """Weight loader that loads (unquantized) Torch tensors."""
+
+ def __init__(self, weight_class: Type[UnquantizedWeight]):
+ """Create a loader. Weights will be wrapped using the given `weights_class`,
+ normally this will be `UnquantizedWeight`, but a quantizer-specific class
+ such as `Fp8Weight` can be used to quantize the weights during loading.
+ """
+ self.weight_class = weight_class
+
+ def get_weights(self, weights: "Weights", prefix: str):
+ return weights.get_tensor(f"{prefix}.weight")
+
+ def get_weights_col_packed(
+ self,
+ weights: "Weights",
+ prefix: str,
+ block_sizes: Union[int, List[int]],
+ ):
+ return self.weight_class(
+ weights.get_packed_sharded(
+ f"{prefix}.weight", dim=0, block_sizes=block_sizes
+ ),
+ )
+
+ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+ w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+ return self.weight_class(torch.cat(w, dim=dim))
+
+ def get_weights_row(self, weights: "Weights", prefix: str):
+ return self.weight_class(
+ weights.get_sharded(f"{prefix}.weight", dim=1),
+ )
+
+
+class Weights:
+ def __init__(
+ self,
+ filenames: List[Path],
+ device,
+ dtype,
+ process_group,
+ weights_loader: WeightsLoader,
+ aliases: Optional[Dict[str, List[str]]] = None,
+ prefix: Optional[str] = None,
+ ):
+ routing = {}
+ for filename in filenames:
+ with safe_open(filename, framework="pytorch") as f:
+ for k in f.keys():
+ if k in routing:
+ raise RuntimeError(
+ f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+ )
+ routing[k] = filename
+ if aliases is None:
+ aliases = {}
+ self.aliases = aliases
+ self.routing = routing
+ self.device = device
+ self.dtype = dtype
+ self.process_group = process_group
+ self.prefix = prefix
+ self.weights_loader = weights_loader
+ self._handles = {}
+
+ def _get_handle(self, filename):
+ if filename not in self._handles:
+ f = safe_open(filename, framework="pytorch")
+ self._handles[filename] = f
+
+ return self._handles[filename]
+
+ def get_filename(self, tensor_name: str) -> Tuple[str, str]:
+ names = [tensor_name]
+ if self.prefix is not None:
+ prefixed = f"{self.prefix}.{tensor_name}"
+ names.append(prefixed)
+ for name in names:
+ filename = self.routing.get(name, None)
+ if filename is not None:
+ return str(filename), name
+
+ aliases = self.aliases.get(name, [])
+ for alias in aliases:
+ filename = self.routing.get(alias, None)
+ if filename is not None:
+ return str(filename), alias
+ raise RuntimeError(f"weight {tensor_name} does not exist")
+
+ def _get_slice(self, tensor_name: str):
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ slice_ = f.get_slice(tensor_name)
+ return slice_
+
+ def has_tensor(self, tensor_name: str):
+ try:
+ self.get_filename(tensor_name)
+ except Exception:
+ return False
+ return True
+
+ def get_shape(self, tensor_name: str):
+ return self._get_slice(tensor_name).get_shape()
+
+ def get_tensor(
+ self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
+ ) -> torch.Tensor:
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ tensor = f.get_tensor(tensor_name)
+ # Special case for gptq which shouldn't convert
+ # u4 which are disguised as int32. Exl2 uses int16
+ # as well. FP8 uses torch.float8_e4m3fn
+ if (
+ tensor.dtype
+ not in [
+ torch.float8_e4m3fn,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.int64,
+ ]
+ and to_dtype
+ ):
+ tensor = tensor.to(dtype=self.dtype)
+ if to_device:
+ tensor = tensor.to(device=self.device)
+ return tensor
+
+ def get_partial_sharded(
+ self, tensor_name: str, dim: int, to_device=True, to_dtype=True
+ ):
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ slice_ = f.get_slice(tensor_name)
+ world_size = self.process_group.size()
+ rank = self.process_group.rank()
+
+ size = slice_.get_shape()[dim]
+ block_size = (size + world_size - 1) // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+
+ if dim == 0:
+ tensor = slice_[start:stop]
+ elif dim == 1:
+ tensor = slice_[:, start:stop]
+ else:
+ raise NotImplementedError("Let's make that generic when needed")
+ # Special case for gptq which shouldn't convert
+ # u4 which are disguised as int32. exl2 uses int16.
+ # FP8 uses torch.float8_e4m3fn.
+ if (
+ tensor.dtype
+ not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
+ and to_dtype
+ ):
+ tensor = tensor.to(dtype=self.dtype)
+ if to_device:
+ tensor = tensor.to(device=self.device)
+ return tensor
+
+ def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
+ filename, tensor_name = self.get_filename(tensor_name)
+ f = self._get_handle(filename)
+ slice_ = f.get_slice(tensor_name)
+ world_size = self.process_group.size()
+ size = slice_.get_shape()[dim]
+ assert (
+ size % world_size == 0
+ ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
+ return self.get_partial_sharded(
+ tensor_name, dim, to_device=to_device, to_dtype=to_dtype
+ )
+
+ def get_packed_sharded(
+ self,
+ tensor_name: str,
+ dim: int,
+ block_sizes: Union[int, List[int]],
+ to_dtype=True,
+ ) -> torch.Tensor:
+ """
+ Get a shard from a tensor that packs multiple tensors.
+
+ When a tensor packs multiple tensors (such as QKV or an up
+ projection + gate projection), sharding with `get_sharded` is not
+ safe since it would not split the packed tensors across shards.
+
+ This method shards a tensor, such that the packed tensors are
+ split across shards.
+
+ The columns are split in equally sized blocks when `block_sizes` is an `int`, or
+ in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
+ divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
+ convenient for e.g. splitting QKV without knowing the storage details of
+ quantized weights.
+ """
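+ # Illustrative example: for a packed tensor with 1024 columns,
+ # block_sizes=[2, 1, 1] and world_size=2, each rank gets
+ # [256, 128, 128] columns taken from inside each block rather than
+ # one contiguous 512-column slice of the packed tensor.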
+ slice_ = self._get_slice(tensor_name)
+ total_size = slice_.get_shape()[dim]
+ block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
+
+ world_size = self.process_group.size()
+ rank = self.process_group.rank()
+
+ tensors = []
+ block_offset = 0
+ for block_size in block_sizes:
+ assert (
+ block_size % world_size == 0
+ ), f"Prepacked tensor cannot be sharded across {world_size} shards"
+ shard_block_size = block_size // world_size
+ start = rank * shard_block_size
+ stop = (rank + 1) * shard_block_size
+ if dim == 0:
+ tensor = slice_[block_offset + start : block_offset + stop]
+ elif dim == 1:
+ tensor = slice_[:, block_offset + start : block_offset + stop]
+ else:
+ raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
+ tensors.append(tensor)
+ block_offset += block_size
+ tensor = torch.cat(tensors, dim=dim)
+ tensor = tensor.to(device=self.device)
+
+ # Avoid casting quantizer dtypes.
+ if (
+ tensor.dtype
+ not in [
+ torch.float8_e4m3fn,
+ torch.int8,
+ torch.int16,
+ torch.int32,
+ torch.int64,
+ ]
+ and to_dtype
+ ):
+ tensor = tensor.to(dtype=self.dtype)
+
+ return tensor
+
+ def get_weights(self, prefix: str):
+ return self.weights_loader.get_weights(self, prefix)
+
+ def get_weights_col_packed_qkv(
+ self,
+ prefix: str,
+ num_heads: int,
+ num_key_value_heads: int,
+ ):
+ return self.get_weights_col_packed(
+ prefix, [num_heads, num_key_value_heads, num_key_value_heads]
+ )
+
+ def get_weights_col_packed_gate_up(self, prefix: str):
+ return self.get_weights_col_packed(prefix, 2)
+
+ def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
+ """
+ The columns are split in equally sized blocks when `block_sizes` is an `int`, or
+ in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
+ divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
+ convenient for e.g. splitting QKV without knowing the storage details of
+ quantized weights.
+ """
+ return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
+
+ def get_weights_col(self, prefix: str):
+ return self.weights_loader.get_weights_col(self, prefix)
+
+ def get_multi_weights_col(self, prefixes: List[str], dim: int):
+ return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
+
+ def get_tensor_shard(self, var, dim):
+ world_size = self.process_group.size()
+ rank = self.process_group.rank()
+ block_size = var.size()[dim] // world_size
+ start = rank * block_size
+ stop = (rank + 1) * block_size
+ if dim == 0:
+ tensor = var[start:stop]
+ elif dim == 1:
+ tensor = var[:, start:stop]
+ else:
+ raise NotImplementedError("Let's make that generic when needed")
+ tensor = tensor.to(dtype=self.dtype)
+ tensor = tensor.to(device=self.device)
+ return tensor
+
+ def get_weights_row(self, prefix: str):
+ return self.weights_loader.get_weights_row(self, prefix)
+
+ @contextmanager
+ def use_loader(self, weights_loader: WeightsLoader):
+ """
+ This method is a context manager that can be used to use `Weights` with
+ a different loader for the duration of the context.
+ """
+
+ old_loader = self.weights_loader
+ self.weights_loader = weights_loader
+ try:
+ yield
+ finally:
+ self.weights_loader = old_loader
+
+ @property
+ def loader(self):
+ return self.weights_loader
+
+
+def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
+ """
+ Convert block count or proportions to block sizes.
+
+ This function accepts
+
+ - The number of blocks (int), in which case the block size is
+ total_size//blocks; or
+ - A list of block proportions (List[int]).
+
+ In the latter case, if sum(blocks) < total_size, the ratios between
+ the proportions are preserved. For instance, if blocks is
+ [2, 1, 1] and total_size is 1024, the returned block sizes are
+ [512, 256, 256].
+ """
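+ # e.g. blocks=4 and total_size=1024 -> [256, 256, 256, 256]
+ # blocks=[2, 1, 1] and total_size=1024 -> [512, 256, 256]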
+ if isinstance(blocks, list):
+ total_blocks = sum(blocks)
+ assert (
+ total_size % total_blocks == 0
+ ), f"Cannot split {total_size} in proportional blocks: {blocks}"
+ part_size = total_size // total_blocks
+ return [part_size * block for block in blocks]
+ else:
+ assert total_size % blocks == 0, f"Prepacked size {total_size} is not divisible by {blocks} blocks"
+ single_size = total_size // blocks
+ return [single_size] * blocks
diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh
new file mode 100644
index 000000000..a5c3f5e1d
--- /dev/null
+++ b/backends/gaudi/tgi-entrypoint.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
+
+# Check if --sharded argument is present in the command line arguments
+if [[ "$*" == *"--sharded true"* ]]; then
+ echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
+ export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
+fi
+
+text-generation-launcher "$@"
diff --git a/backends/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs
index 3068a61c3..822b03072 100644
--- a/backends/grpc-metadata/src/lib.rs
+++ b/backends/grpc-metadata/src/lib.rs
@@ -8,7 +8,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
-impl<'a> Injector for MetadataInjector<'a> {
+impl Injector for MetadataInjector<'_> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
diff --git a/backends/llamacpp/Cargo.toml b/backends/llamacpp/Cargo.toml
new file mode 100644
index 000000000..685a313f1
--- /dev/null
+++ b/backends/llamacpp/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "text-generation-router-llamacpp"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[build-dependencies]
+bindgen = "0.71.1"
+pkg-config = "0.3.31"
+
+[dependencies]
+async-trait = "0.1.85"
+clap = "4.5.27"
+hf-hub.workspace = true
+num_cpus = "1.16.0"
+text-generation-router = { path = "../../router" }
+thiserror = "2.0.11"
+tokenizers.workspace = true
+tokio = { version = "1.43.0", features = ["process"] }
+tokio-stream = "0.1.17"
+tracing = "0.1.41"
diff --git a/backends/llamacpp/README.md b/backends/llamacpp/README.md
new file mode 100644
index 000000000..0971efc5a
--- /dev/null
+++ b/backends/llamacpp/README.md
@@ -0,0 +1,24 @@
+# Llamacpp backend
+
+If all your dependencies are installed at the system level, running
+`cargo build` should be sufficient. However, if you want to experiment
+with different versions of llama.cpp, some additional setup is required.
+
+## Install llama.cpp
+
+ LLAMACPP_PREFIX=$(pwd)/llama.cpp.out
+
+ git clone https://github.com/ggerganov/llama.cpp
+ cd llama.cpp
+ cmake -B build \
+ -DCMAKE_INSTALL_PREFIX="$LLAMACPP_PREFIX" \
+ -DLLAMA_BUILD_COMMON=OFF \
+ -DLLAMA_BUILD_TESTS=OFF \
+ -DLLAMA_BUILD_EXAMPLES=OFF \
+ -DLLAMA_BUILD_SERVER=OFF
+ cmake --build build --config Release -j
+ cmake --install build
+
+## Build TGI
+
+ PKG_CONFIG_PATH="$LLAMACPP_PREFIX/lib/pkgconfig" cargo build
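+
+## Run TGI (example)
+
+A hypothetical invocation, assuming the backend was built as above and a GGUF
+file is already available; the binary name comes from this crate's package
+name and the flags follow the clap arguments defined in `src/main.rs`:
+
+ text-generation-router-llamacpp \
+ --model-id "$MODEL_ID" \
+ --model-gguf /path/to/model.gguf \
+ --port 3000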
diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs
new file mode 100644
index 000000000..8f00f3b5b
--- /dev/null
+++ b/backends/llamacpp/build.rs
@@ -0,0 +1,49 @@
+use bindgen::callbacks::{ItemInfo, ParseCallbacks};
+use std::env;
+use std::path::PathBuf;
+
+#[derive(Debug)]
+struct PrefixStripper;
+
+impl ParseCallbacks for PrefixStripper {
+ fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
+ item_info.name.strip_prefix("llama_").map(str::to_string)
+ }
+}
+
+fn main() {
+ if let Some(cuda_version) = option_env!("CUDA_VERSION") {
+ let mut version: Vec<&str> = cuda_version.split('.').collect();
+ if version.len() > 2 {
+ version.pop();
+ }
+ let cuda_version = format!("cuda-{}", version.join("."));
+ pkg_config::Config::new().probe(&cuda_version).unwrap();
+ }
+ let llama = pkg_config::Config::new().probe("llama").unwrap();
+
+ for path in &llama.link_paths {
+ println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
+ }
+ if cfg!(target_os = "linux") {
+ println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
+ }
+ let bindings = bindgen::Builder::default()
+ .clang_args(
+ llama
+ .include_paths
+ .iter()
+ .map(|p| format!("-I{}", p.display())),
+ )
+ .header_contents("llama_bindings.h", "#include <llama.h>")
+ .prepend_enum_name(false)
+ .parse_callbacks(Box::new(PrefixStripper))
+ .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
+ .generate()
+ .expect("Unable to generate bindings");
+
+ let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
+ bindings
+ .write_to_file(out_path.join("llamacpp.rs"))
+ .expect("Couldn't write bindings!");
+}
diff --git a/backends/llamacpp/requirements.txt b/backends/llamacpp/requirements.txt
new file mode 100644
index 000000000..293cd2055
--- /dev/null
+++ b/backends/llamacpp/requirements.txt
@@ -0,0 +1,4 @@
+transformers==4.49
+huggingface-hub==0.28.1
+hf-transfer==0.1.9
+torch==2.6.0
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
new file mode 100644
index 000000000..3405cfadd
--- /dev/null
+++ b/backends/llamacpp/src/backend.rs
@@ -0,0 +1,674 @@
+use crate::llamacpp;
+
+use async_trait::async_trait;
+use std::ffi::CString;
+use std::mem::replace;
+use std::str::FromStr;
+use std::sync::{mpsc, Once};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::{FinishReason, Token};
+use thiserror::Error;
+use tokenizers::Tokenizer;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::{oneshot, watch};
+use tokio::task::{spawn, spawn_blocking};
+use tokio::time::{timeout, Duration, Instant};
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::instrument;
+use tracing::{debug, error, info, trace, warn};
+
+#[derive(Debug, Clone, Copy)]
+pub enum LlamacppSplitMode {
+ GPU(usize),
+ Layer,
+ Row,
+}
+
+impl FromStr for LlamacppSplitMode {
+ type Err = String;
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s.to_lowercase().as_str() {
+ "layer" => Ok(LlamacppSplitMode::Layer),
+ "row" => Ok(LlamacppSplitMode::Row),
+ _ => match s.parse::<usize>() {
+ Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
+ Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
+ },
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+pub enum LlamacppNuma {
+ Disabled,
+ Distribute,
+ Isolate,
+ Numactl,
+ Mirror,
+}
+
+#[allow(non_camel_case_types)]
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+pub enum LlamacppGGMLType {
+ F32,
+ F16,
+ Q4_0,
+ Q4_1,
+ Q5_0,
+ Q5_1,
+ Q8_0,
+ Q8_1,
+ Q2_K,
+ Q3_K,
+ Q4_K,
+ Q5_K,
+ Q6_K,
+ Q8_K,
+ IQ2_XXS,
+ IQ2_XS,
+ IQ3_XXS,
+ IQ1_S,
+ IQ4_NL,
+ IQ3_S,
+ IQ2_S,
+ IQ4_XS,
+ I8,
+ I16,
+ I32,
+ I64,
+ F64,
+ IQ1_M,
+ BF16,
+ TQ1_0,
+ TQ2_0,
+}
+
+// TODO: macro
+impl LlamacppGGMLType {
+ fn to_ggml_type(self) -> llamacpp::ggml_type {
+ match self {
+ LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
+ LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
+ LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
+ LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
+ LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
+ LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
+ LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
+ LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
+ LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
+ LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
+ LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
+ LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
+ LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
+ LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
+ LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
+ LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
+ LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
+ LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
+ LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
+ LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
+ LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
+ LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
+ LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
+ LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
+ LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
+ LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
+ LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
+ LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
+ LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
+ LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
+ LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
+ }
+ }
+}
+
+pub struct LlamacppConfig {
+ pub model_gguf: String,
+ pub max_batch_total_tokens: usize,
+ pub max_physical_batch_total_tokens: usize,
+ pub max_batch_size: usize,
+ pub batch_timeout: Duration,
+ pub n_threads: usize,
+ pub n_threads_batch: usize,
+ pub n_gpu_layers: usize,
+ pub split_mode: LlamacppSplitMode,
+ pub numa: LlamacppNuma,
+ pub defrag_threshold: f32,
+ pub use_mmap: bool,
+ pub use_mlock: bool,
+ pub offload_kqv: bool,
+ pub flash_attention: bool,
+ pub type_k: LlamacppGGMLType,
+ pub type_v: LlamacppGGMLType,
+}
+
+#[derive(Debug)]
+struct LlamacppRequest {
+ input_ids: Vec<i32>,
+ top_k: i32,
+ top_p: f32,
+ typical_p: f32,
+ min_keep: usize,
+ temp: f32,
+ seed: u32,
+ penalty_last_n: i32,
+ penalty_repeat: f32,
+ penalty_freq: f32,
+ penalty_present: f32,
+ max_new_tokens: usize,
+ tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
+ time: Instant,
+}
+
+pub struct LlamacppBackend {
+ tx: UnboundedSender<LlamacppRequest>,
+ status: watch::Receiver<bool>,
+}
+
+impl LlamacppRequest {
+ fn new(
+ from: &ValidGenerateRequest,
+ tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
+ ) -> Option<Self> {
+ from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
+ input_ids: input_ids.iter().map(|&x| x as i32).collect(),
+ top_k: from.parameters.top_k as _,
+ top_p: from.parameters.top_p as _,
+ typical_p: from.parameters.typical_p as _,
+ min_keep: 0, // disabled
+ temp: from.parameters.temperature as _,
+ seed: from.parameters.seed as _,
+ penalty_last_n: 64, // 0 = disabled, -1 = context size
+ penalty_repeat: from.parameters.repetition_penalty as _,
+ penalty_freq: from.parameters.frequency_penalty as _,
+ penalty_present: 0.0, // disabled
+ max_new_tokens: from.stopping_parameters.max_new_tokens as _,
+ tx,
+ time: Instant::now(),
+ })
+ }
+}
+
+struct Llamacpp {
+ model: *mut llamacpp::llama_model,
+ ctx: *mut llamacpp::llama_context,
+ vocab: *const llamacpp::llama_vocab,
+ logprobs: Vec<llamacpp::llama_token_data>,
+ batch: llamacpp::llama_batch,
+}
+
+extern "C" fn llamacpp_log_callback(
+ level: llamacpp::ggml_log_level,
+ msg: *const std::os::raw::c_char,
+ _user_data: *mut std::os::raw::c_void,
+) {
+ let cmsg = unsafe { std::ffi::CStr::from_ptr(msg) };
+ let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();
+
+ match level {
+ llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
+ llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
+ llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
+ llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
+ _ => trace!(target: "llamacpp", "{}", rmsg),
+ }
+}
+
+impl Llamacpp {
+ fn new(conf: LlamacppConfig) -> Result<Self, BackendError> {
+ let gguf = CString::new(conf.model_gguf)?;
+
+ let model = unsafe {
+ let mut params = llamacpp::model_default_params();
+ params.n_gpu_layers = conf.n_gpu_layers as _;
+ params.split_mode = match conf.split_mode {
+ LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
+ LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
+ LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
+ };
+ params.main_gpu = match conf.split_mode {
+ LlamacppSplitMode::GPU(n) => n as _,
+ _ => 0,
+ };
+ params.use_mmap = conf.use_mmap;
+ params.use_mlock = conf.use_mlock;
+ llamacpp::model_load_from_file(gguf.as_ptr(), params)
+ };
+ if model.is_null() {
+ return Err(BackendError::Llamacpp("Failed to load model".to_string()));
+ }
+ let ctx = unsafe {
+ let mut params = llamacpp::context_default_params();
+ params.n_ctx = conf.max_batch_total_tokens as _;
+ params.n_batch = conf.max_batch_total_tokens as _;
+ params.n_ubatch = conf.max_physical_batch_total_tokens as _;
+ params.n_seq_max = conf.max_batch_size as _;
+ params.n_threads = conf.n_threads as _;
+ params.n_threads_batch = conf.n_threads_batch as _;
+ params.defrag_thold = conf.defrag_threshold;
+ params.offload_kqv = conf.offload_kqv;
+ params.flash_attn = conf.flash_attention;
+ params.type_k = conf.type_k.to_ggml_type();
+ params.type_v = conf.type_v.to_ggml_type();
+ params.no_perf = true;
+ llamacpp::init_from_model(model, params)
+ };
+ if ctx.is_null() {
+ return Err(BackendError::Llamacpp("Failed to init context".to_string()));
+ }
+ let vocab = unsafe { llamacpp::model_get_vocab(model) };
+ if vocab.is_null() {
+ return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
+ }
+ let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
+ let mut logprobs = Vec::with_capacity(n_tokens as usize);
+
+ for token in 0..n_tokens {
+ logprobs.push(llamacpp::llama_token_data {
+ id: token,
+ logit: 0.0,
+ p: 0.0,
+ });
+ }
+ let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
+ Ok(Llamacpp {
+ model,
+ ctx,
+ vocab,
+ logprobs,
+ batch,
+ })
+ }
+
+ fn decode(&mut self) -> i32 {
+ unsafe { llamacpp::decode(self.ctx, self.batch) }
+ }
+
+ fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
+ unsafe {
+ llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
+ }
+ }
+
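+ /// Append one token to the pending llama_batch, recording its position,
+ /// its sequence id and whether logits must be computed for it.
+ /// Returns the index of the token within the batch.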
+ fn batch_push(
+ &mut self,
+ token: llamacpp::llama_token,
+ pos: llamacpp::llama_pos,
+ seq_id: llamacpp::llama_seq_id,
+ logits: bool,
+ ) -> usize {
+ let n = self.batch.n_tokens as usize;
+ unsafe {
+ *self.batch.token.add(n) = token;
+ *self.batch.pos.add(n) = pos;
+ *self.batch.n_seq_id.add(n) = 1;
+ *(*self.batch.seq_id.add(n)).add(0) = seq_id;
+ *self.batch.logits.add(n) = logits as i8;
+ }
+ self.batch.n_tokens += 1;
+ n
+ }
+}
+
+impl Drop for Llamacpp {
+ fn drop(&mut self) {
+ if !self.ctx.is_null() {
+ unsafe { llamacpp::free(self.ctx) };
+ }
+ if !self.model.is_null() {
+ unsafe { llamacpp::model_free(self.model) };
+ }
+ unsafe { llamacpp::batch_free(self.batch) };
+ }
+}
+
+struct LlamacppSampler {
+ chain: *mut llamacpp::llama_sampler,
+}
+
+impl LlamacppSampler {
+ fn new(req: &LlamacppRequest) -> Option<Self> {
+ let chain = unsafe {
+ let params = llamacpp::sampler_chain_default_params();
+ llamacpp::sampler_chain_init(params)
+ };
+ if chain.is_null() {
+ error!("Failed to init sampler");
+ return None;
+ }
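+ // Build the sampler chain; samplers are applied in insertion order:
+ // top_k -> top_p -> typical_p -> temperature -> penalties, followed by
+ // the final seeded probabilistic selection (dist).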
+ let (top_k, top_p, typical_p, temp, penalties, dist) = unsafe {
+ (
+ llamacpp::sampler_init_top_k(req.top_k),
+ llamacpp::sampler_init_top_p(req.top_p, req.min_keep),
+ llamacpp::sampler_init_typical(req.typical_p, req.min_keep),
+ llamacpp::sampler_init_temp(req.temp),
+ llamacpp::sampler_init_penalties(
+ req.penalty_last_n,
+ req.penalty_repeat,
+ req.penalty_freq,
+ req.penalty_present,
+ ),
+ llamacpp::sampler_init_dist(req.seed),
+ )
+ };
+ let all = &[
+ ("top_k", top_k),
+ ("top_p", top_p),
+ ("typical_p", typical_p),
+ ("temp", temp),
+ ("penalties", penalties),
+ ("dist", dist),
+ ];
+ let mut failed = false;
+
+ for (k, v) in all {
+ if v.is_null() {
+ error!("Failed to init {k} sampler");
+ failed = true;
+ } else {
+ unsafe { llamacpp::sampler_chain_add(chain, *v) };
+ }
+ }
+ if failed {
+ unsafe { llamacpp::sampler_free(chain) };
+ None
+ } else {
+ Some(LlamacppSampler { chain })
+ }
+ }
+
+ fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
+ let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
+ for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
+ *logprob = llamacpp::llama_token_data {
+ id: token as _,
+ logit: unsafe { *logits.add(token) },
+ p: 0.0,
+ };
+ }
+ let mut view = llamacpp::llama_token_data_array {
+ data: llamacpp.logprobs.as_mut_ptr(),
+ size: llamacpp.logprobs.len(),
+ selected: -1,
+ sorted: false,
+ };
+ unsafe {
+ llamacpp::sampler_apply(self.chain, &mut view);
+ let logprob = *view.data.offset(view.selected as _);
+ llamacpp::sampler_accept(self.chain, logprob.id);
+ (logprob.id, logprob.p.ln())
+ }
+ }
+}
+
+impl Drop for LlamacppSampler {
+ fn drop(&mut self) {
+ if !self.chain.is_null() {
+ unsafe { llamacpp::sampler_free(self.chain) };
+ }
+ }
+}
+
+struct LlamacppSeq {
+ id: usize,
+ batch_pos: usize,
+ token: llamacpp::llama_token,
+ pos: llamacpp::llama_pos,
+ sampler: LlamacppSampler,
+ text: String,
+ n_new_tokens: usize,
+ running: bool,
+}
+
+static INIT: Once = Once::new();
+
+impl LlamacppBackend {
+ pub fn new(
+ conf: LlamacppConfig,
+ tokenizer: Tokenizer,
+ ) -> (
+ Self,
+ oneshot::Receiver<Result<(), BackendError>>,
+ watch::Sender<bool>,
+ ) {
+ // Setup llama & export logs, once and for all
+ INIT.call_once(|| unsafe {
+ llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
+ llamacpp::backend_init();
+ llamacpp::numa_init(match conf.numa {
+ LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
+ LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
+ LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
+ LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
+ LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
+ });
+ });
+
+ let (status_tx, status_rx) = watch::channel(false);
+ let (shutdown_tx, shutdown_rx) = watch::channel(false);
+ let (ok_tx, ok_rx) = oneshot::channel();
+ let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
+ let (sync_tx, sync_rx) = mpsc::channel();
+
+ spawn(async move {
+ let mut n_tokens = 0;
+ let mut requests = Vec::with_capacity(conf.max_batch_size);
+
+ let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
+ if !requests.is_empty() {
+ let _ =
+ sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
+ *n_tokens = 0;
+ }
+ };
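+ // Batching policy: accumulate incoming requests until the token budget
+ // (max_batch_total_tokens) or the request count (max_batch_size) would
+ // be exceeded, or until batch_timeout elapses with no new request, then
+ // hand the batch over to the blocking inference thread via sync_tx.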
+ loop {
+ match timeout(conf.batch_timeout, rx.recv()).await {
+ Ok(Some(request)) => {
+ let n_tokens_to_add = request.input_ids.len();
+
+ if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens {
+ flush(&mut requests, &mut n_tokens);
+ }
+ n_tokens += n_tokens_to_add;
+ requests.push(request);
+
+ if requests.len() == conf.max_batch_size {
+ flush(&mut requests, &mut n_tokens);
+ }
+ }
+ Ok(None) => break, // closed
+ Err(_) => flush(&mut requests, &mut n_tokens), // timeout
+ }
+ }
+ });
+
+ spawn_blocking(move || {
+ let mut llamacpp = match Llamacpp::new(conf) {
+ Ok(v) => {
+ let _ = ok_tx.send(Ok(()));
+ v
+ }
+ Err(e) => {
+ let _ = ok_tx.send(Err(e));
+ return;
+ }
+ };
+ let vocab = tokenizer.get_added_vocabulary();
+
+ // health() returns true
+ let _ = status_tx.send(true);
+
+ while let Ok(requests) = sync_rx.recv() {
+ if *shutdown_rx.borrow() {
+ break;
+ }
+ let start_time = Instant::now();
+ let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
+ llamacpp.batch.n_tokens = 0;
+
+ for (seq_id, request) in requests.iter().enumerate() {
+ debug!("Request: {:?}", request);
+ // TODO remove this
+ let sampler = match LlamacppSampler::new(request) {
+ Some(sampler) => sampler,
+ _ => {
+ let _ = request.tx.send(Err(InferError::IncompleteGeneration));
+ continue;
+ }
+ };
+ let last_pos = request.input_ids.len() - 1;
+
+ for (pos, &token_id) in request.input_ids.iter().enumerate() {
+ llamacpp.batch_push(
+ token_id as llamacpp::llama_token,
+ pos as llamacpp::llama_pos,
+ seq_id as llamacpp::llama_seq_id,
+ pos == last_pos, // check samplers
+ );
+ }
+ seqs.push(LlamacppSeq {
+ id: seq_id,
+ batch_pos: llamacpp.batch.n_tokens as usize - 1,
+ token: llamacpp::LLAMA_TOKEN_NULL,
+ pos: last_pos as llamacpp::llama_pos + 1,
+ sampler,
+ text: String::with_capacity(1024),
+ n_new_tokens: 0,
+ running: true,
+ });
+ }
+ while llamacpp.batch.n_tokens > 0 {
+ if llamacpp.decode() != 0 {
+ warn!("llama_decode failed, clearing kv cache");
+ llamacpp.clear_kv_cache(-1);
+ for seq in seqs.iter_mut() {
+ let _ = requests[seq.id]
+ .tx
+ .send(Err(InferError::IncompleteGeneration));
+ seq.running = false;
+ }
+ break;
+ }
+ for seq in seqs.iter_mut() {
+ if !seq.running {
+ continue;
+ }
+ let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
+ seq.n_new_tokens += 1;
+ seq.token = next;
+
+ let piece = match tokenizer.decode(&[next as u32], false) {
+ Ok(piece) => piece,
+ Err(e) => {
+ error!("Failed to decode token: {e}");
+ let _ = requests[seq.id]
+ .tx
+ .send(Err(InferError::IncompleteGeneration));
+ seq.running = false;
+ continue;
+ }
+ };
+ let special = vocab.is_special_token(&piece);
+
+ if !special {
+ seq.text.push_str(&piece);
+ }
+ let token = Token {
+ id: next as _,
+ text: piece,
+ logprob,
+ special,
+ };
+ let finish: Option<FinishReason> = {
+ if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {
+ Some(FinishReason::EndOfSequenceToken)
+ } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
+ Some(FinishReason::Length)
+ } else {
+ None
+ }
+ };
+ if let Some(reason) = finish {
+ let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
+ token,
+ top_tokens: vec![],
+ generated_text: GeneratedText {
+ text: seq.text.clone(),
+ generated_tokens: seq.n_new_tokens as _,
+ finish_reason: reason,
+ seed: Some(requests[seq.id].seed as _),
+ },
+ start: start_time,
+ queued: requests[seq.id].time,
+ }));
+ seq.running = false;
+ continue;
+ }
+ let _ = requests[seq.id]
+ .tx
+ .send(Ok(InferStreamResponse::Intermediate {
+ token,
+ top_tokens: vec![],
+ }));
+ }
+ // generate a new batch
+ llamacpp.batch.n_tokens = 0;
+
+ for seq in seqs.iter_mut() {
+ if seq.running {
+ seq.batch_pos =
+ llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
+ seq.pos += 1;
+ } else {
+ llamacpp.clear_kv_cache(seq.id as _);
+ }
+ }
+ }
+ }
+ });
+ (
+ Self {
+ tx,
+ status: status_rx,
+ },
+ ok_rx,
+ shutdown_tx,
+ )
+ }
+}
+
+#[async_trait]
+impl Backend for LlamacppBackend {
+ #[instrument(skip_all)]
+ fn schedule(
+ &self,
+ request: ValidGenerateRequest,
+ ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+ debug!(?request);
+ let (tx, rx) = unbounded_channel::<Result<InferStreamResponse, InferError>>();
+ match LlamacppRequest::new(&request, tx) {
+ Some(v) => match self.tx.send(v) {
+ Err(e) => Err(InferError::GenerationError(e.to_string())),
+ _ => Ok(UnboundedReceiverStream::new(rx)),
+ },
+ _ => Err(InferError::GenerationError("Bad request".to_string())),
+ }
+ }
+
+ async fn health(&self, _: bool) -> bool {
+ *self.status.borrow()
+ }
+
+ fn name(&self) -> &'static str {
+ "llamacpp"
+ }
+}
+
+#[derive(Debug, Error)]
+pub enum BackendError {
+ #[error("CString error: {0}")]
+ CStringError(#[from] std::ffi::NulError),
+ #[error("Llamacpp error: {0}")]
+ Llamacpp(String),
+}
diff --git a/backends/llamacpp/src/llamacpp.rs b/backends/llamacpp/src/llamacpp.rs
new file mode 100644
index 000000000..fb206df27
--- /dev/null
+++ b/backends/llamacpp/src/llamacpp.rs
@@ -0,0 +1,5 @@
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(dead_code)]
+include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
new file mode 100644
index 000000000..b99e9591e
--- /dev/null
+++ b/backends/llamacpp/src/main.rs
@@ -0,0 +1,347 @@
+mod backend;
+mod llamacpp;
+mod quantize;
+
+use quantize::QuantizeType;
+
+use backend::{
+ BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
+ LlamacppSplitMode,
+};
+use clap::Parser;
+use hf_hub::api::tokio::ApiBuilder;
+use hf_hub::{Repo, RepoType};
+use std::path::Path;
+use text_generation_router::{logging, server, usage_stats};
+use thiserror::Error;
+use tokenizers::Tokenizer;
+use tokio::process::Command;
+use tokio::sync::oneshot::error::RecvError;
+use tracing::{error, warn};
+
+/// Backend Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+ /// Name of the model to load.
+ #[clap(long, env)]
+ model_id: String,
+
+ /// Revision of the model.
+ #[clap(default_value = "main", long, env)]
+ revision: String,
+
+ /// Path to the GGUF model file for inference.
+ #[clap(long, env)]
+ model_gguf: Option<String>,
+
+ /// Number of threads to use for generation.
+ #[clap(long, env)]
+ n_threads: Option<usize>,
+
+ /// Number of threads to use for batch processing.
+ #[clap(long, env)]
+ n_threads_batch: Option<usize>,
+
+ /// Number of layers to store in VRAM.
+ #[clap(default_value = "0", long, env)]
+ n_gpu_layers: usize,
+
+ /// Split the model across multiple GPUs.
+ #[clap(default_value = "layer", long, env)]
+ split_mode: LlamacppSplitMode,
+
+ /// Defragment the KV cache if holes/size > threshold.
+ #[clap(default_value = "-1.0", long, env)]
+ defrag_threshold: f32,
+
+ /// Enable NUMA optimizations.
+ #[clap(default_value = "disabled", value_enum, long, env)]
+ numa: LlamacppNuma,
+
+ /// Disable memory mapping for the model.
+ #[clap(long, env)]
+ disable_mmap: bool,
+
+ /// Use memory locking to prevent swapping.
+ #[clap(long, env)]
+ use_mlock: bool,
+
+ /// Disable offloading of KQV operations to the GPU.
+ #[clap(long, env)]
+ disable_offload_kqv: bool,
+
+ /// Disable flash attention. (EXPERIMENTAL)
+ #[clap(long, env)]
+ disable_flash_attention: bool,
+
+ /// Data type used for K cache.
+ #[clap(default_value = "f16", value_enum, long, env)]
+ type_k: LlamacppGGMLType,
+
+ /// Data type used for V cache.
+ #[clap(default_value = "f16", value_enum, long, env)]
+ type_v: LlamacppGGMLType,
+
+ /// Number of tokenizer workers used for payload validation and truncation.
+ #[clap(default_value = "2", long, env)]
+ validation_workers: usize,
+
+ /// Maximum number of concurrent requests.
+ #[clap(long, env)]
+ max_concurrent_requests: Option<usize>,
+
+ /// Maximum number of input tokens per request.
+ #[clap(default_value = "1024", long, env)]
+ max_input_tokens: usize,
+
+ /// Maximum number of total tokens (input + output) per request.
+ #[clap(default_value = "2048", long, env)]
+ max_total_tokens: usize,
+
+ /// Maximum number of tokens in a batch.
+ #[clap(long, env)]
+ max_batch_total_tokens: Option<usize>,
+
+ /// Maximum number of tokens in a physical batch.
+ #[clap(long, env)]
+ max_physical_batch_total_tokens: Option<usize>,
+
+ /// Maximum number of requests per batch.
+ #[clap(long, env)]
+ max_batch_size: Option<usize>,
+
+ /// IP address to listen on.
+ #[clap(default_value = "0.0.0.0", long)]
+ hostname: String,
+
+ /// Port to listen on.
+ #[clap(default_value = "3000", long, short, env)]
+ port: u16,
+
+ /// Enable JSON output format.
+ #[clap(long, env)]
+ json_output: bool,
+
+ /// OTLP endpoint for telemetry data.
+ #[clap(long, env)]
+ otlp_endpoint: Option<String>,
+
+ /// Service name for OTLP telemetry.
+ #[clap(default_value = "text-generation-inference.router", long, env)]
+ otlp_service_name: String,
+
+ /// Allowed origins for CORS.
+ #[clap(long, env)]
+ cors_allow_origin: Option<Vec<String>>,
+
+ /// Path to the tokenizer configuration file.
+ #[clap(long, env)]
+ tokenizer_config_path: Option<String>,
+
+ /// Disable grammar support.
+ #[clap(long, env)]
+ disable_grammar_support: bool,
+
+ /// Maximum number of inputs per request.
+ #[clap(default_value = "4", long, env)]
+ max_client_batch_size: usize,
+
+ /// Level of usage statistics collection.
+ #[clap(default_value = "on", long, env)]
+ usage_stats: usage_stats::UsageStatsLevel,
+
+ /// Maximum payload size in bytes.
+ #[clap(default_value = "2000000", long, env)]
+ payload_limit: usize,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), RouterError> {
+ let args = Args::parse();
+
+ logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
+
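+ // Derive runtime defaults: thread counts fall back to the number of
+ // logical CPUs, the maximum batch size to the batch thread count, and
+ // the token budgets to `max_batch_size * max_total_tokens` unless set.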
+ let n_threads = match args.n_threads {
+ Some(0) | None => num_cpus::get(),
+ Some(threads) => threads,
+ };
+ let n_threads_batch = match args.n_threads_batch {
+ Some(0) | None => n_threads,
+ Some(threads) => threads,
+ };
+ let max_batch_size = match args.max_batch_size {
+ Some(0) | None => n_threads_batch,
+ Some(threads) => threads,
+ };
+ let max_batch_total_tokens = match args.max_batch_total_tokens {
+ None => max_batch_size * args.max_total_tokens,
+ Some(size) => size,
+ };
+ let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
+ None => max_batch_total_tokens,
+ Some(size) => size,
+ };
+ let max_concurrent_requests = match args.max_concurrent_requests {
+ None => max_batch_size * 2,
+ Some(size) => size,
+ };
+ if args.max_input_tokens >= args.max_total_tokens {
+ return Err(RouterError::ArgumentValidation(
+ "`max_input_tokens` must be < `max_total_tokens`".to_string(),
+ ));
+ }
+ if args.max_total_tokens > max_batch_total_tokens {
+ return Err(RouterError::ArgumentValidation(
+ "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
+ ));
+ }
+ if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
+ return Err(RouterError::ArgumentValidation(
+ "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
+ ));
+ }
+
+ let api_builder = || {
+ let mut builder = ApiBuilder::new().with_progress(true);
+
+ if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
+ builder = builder.with_cache_dir(cache_dir.into());
+ }
+ if let Ok(token) = std::env::var("HF_TOKEN") {
+ builder = builder.with_token(token.into());
+ }
+ if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
+ builder = builder.with_user_agent("origin", origin.as_str());
+ }
+ builder
+ };
+ let api_repo = api_builder().build()?.repo(Repo::with_revision(
+ args.model_id.clone(),
+ RepoType::Model,
+ args.revision.clone(),
+ ));
+
+ let tokenizer_path = api_repo.get("tokenizer.json").await?;
+ let tokenizer = Tokenizer::from_file(&tokenizer_path)?;
+
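+ // If no GGUF file is provided, download the model repository, convert it
+ // with llama.cpp's `convert_hf_to_gguf.py` script (expected to be on the
+ // PATH) and quantize the result to Q4_0 before loading it.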
+ let model_gguf = if let Some(model_gguf) = args.model_gguf {
+ model_gguf
+ } else {
+ let model_gguf = format!("models/{}/model.gguf", args.model_id);
+ let model_gguf_path = Path::new(&model_gguf);
+
+ if !model_gguf_path.exists() {
+ let tmp_gguf = "models/tmp.gguf";
+
+ if let Some(parent) = Path::new(model_gguf_path).parent() {
+ std::fs::create_dir_all(parent)?;
+ }
+ let cache_path = tokenizer_path.parent().unwrap();
+
+ for sibling in api_repo.info().await?.siblings {
+ let _ = api_repo.get(&sibling.rfilename).await?;
+ }
+ let status = Command::new("convert_hf_to_gguf.py")
+ .arg("--outfile")
+ .arg(tmp_gguf)
+ .arg(cache_path)
+ .spawn()?
+ .wait()
+ .await?;
+
+ if !status.success() {
+ let exit_code = status.code().unwrap_or(-1);
+ error!("Failed to generate GGUF, exit code: {}", exit_code);
+ return Err(RouterError::CommandError(exit_code));
+ }
+ quantize::model(tmp_gguf, &model_gguf, QuantizeType::MostlyQ4_0, n_threads)
+ .map_err(RouterError::QuantizeError)?;
+ }
+ model_gguf
+ };
+
+ let (backend, ok, shutdown) = LlamacppBackend::new(
+ LlamacppConfig {
+ model_gguf,
+ n_threads,
+ n_threads_batch,
+ n_gpu_layers: args.n_gpu_layers,
+ split_mode: args.split_mode,
+ defrag_threshold: args.defrag_threshold,
+ numa: args.numa,
+ use_mmap: !args.disable_mmap,
+ use_mlock: args.use_mlock,
+ flash_attention: !args.disable_flash_attention,
+ type_k: args.type_k,
+ type_v: args.type_v,
+ offload_kqv: !args.disable_offload_kqv,
+ max_batch_total_tokens,
+ max_physical_batch_total_tokens,
+ max_batch_size,
+ batch_timeout: tokio::time::Duration::from_millis(5),
+ },
+ tokenizer,
+ );
+ ok.await??;
+
+ if cfg!(debug_assertions) {
+ warn!("Graceful shutdown disabled!");
+ let _ = tokio::task::spawn(async move {
+ let _ = tokio::signal::ctrl_c().await;
+ let _ = shutdown.send(true);
+ });
+ }
+
+ server::run(
+ backend,
+ max_concurrent_requests,
+ 0, // max_best_of
+ 0, // max_stop_sequences
+ 0, // max_top_n_tokens
+ args.max_input_tokens,
+ args.max_total_tokens,
+ args.validation_workers,
+ None, // api_key
+ args.model_id, // tokenizer_name
+ args.tokenizer_config_path,
+ Some(args.revision),
+ false, // trust_remote_code
+ args.hostname,
+ args.port,
+ args.cors_allow_origin,
+ false, // ngrok,
+ None, // ngrok_authtoken,
+ None, // ngrok_edge,
+ args.disable_grammar_support,
+ args.max_client_batch_size,
+ args.usage_stats,
+ args.payload_limit,
+ )
+ .await?;
+ Ok(())
+}
+
+#[derive(Debug, Error)]
+enum RouterError {
+ #[error("Argument validation error: {0}")]
+ ArgumentValidation(String),
+ #[error("Tokenizer error: {0}")]
+ Tokenizer(#[from] tokenizers::Error),
+ #[error("Backend error: {0}")]
+ Backend(#[from] BackendError),
+ #[error("WebServer error: {0}")]
+ WebServer(#[from] server::WebServerError),
+ #[error("Recv error: {0}")]
+ RecvError(#[from] RecvError),
+ #[error("Io error: {0}")]
+ IoError(#[from] std::io::Error),
+ #[error("Var error: {0}")]
+ VarError(#[from] std::env::VarError),
+ #[error("Quantize error: {0}")]
+ QuantizeError(String),
+ #[error("Command error: {0}")]
+ CommandError(i32),
+ #[error("HF hub error: {0}")]
+ HubError(#[from] hf_hub::api::tokio::ApiError),
+}
diff --git a/backends/llamacpp/src/quantize.rs b/backends/llamacpp/src/quantize.rs
new file mode 100644
index 000000000..31307becf
--- /dev/null
+++ b/backends/llamacpp/src/quantize.rs
@@ -0,0 +1,35 @@
+use crate::llamacpp;
+
+use std::ffi::CString;
+
+#[repr(u32)]
+#[derive(Debug, Clone, Copy)]
+pub enum QuantizeType {
+ MostlyQ4_0 = 2,
+}
+
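+/// Quantize the GGUF model at `input_path` into `output_path` with the
+/// requested quantization type and number of threads; a thin safe wrapper
+/// around llama.cpp's `llama_model_quantize`.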
+pub fn model(
+ input_path: &str,
+ output_path: &str,
+ ftype: QuantizeType,
+ n_threads: usize,
+) -> Result<(), String> {
+ let c_input_path =
+ CString::new(input_path).map_err(|e| format!("Failed to convert input path: {}", e))?;
+
+ let c_output_path =
+ CString::new(output_path).map_err(|e| format!("Failed to convert output path: {}", e))?;
+
+ let result = unsafe {
+ let mut params = llamacpp::model_quantize_default_params();
+ params.nthread = n_threads as _;
+ params.ftype = ftype as _;
+ params.quantize_output_tensor = true;
+ llamacpp::model_quantize(c_input_path.as_ptr(), c_output_path.as_ptr(), ¶ms)
+ };
+ if result == 0 {
+ Ok(())
+ } else {
+ Err(format!("Quantization failed, error code: {}", result))
+ }
+}
diff --git a/backends/neuron/Cargo.toml b/backends/neuron/Cargo.toml
new file mode 100644
index 000000000..72f92e69c
--- /dev/null
+++ b/backends/neuron/Cargo.toml
@@ -0,0 +1,47 @@
+[workspace]
+members = [
+ "backends/v2",
+ "backends/grpc-metadata",
+ "launcher",
+ "router"
+]
+default-members = [
+ "backends/v2",
+ "backends/grpc-metadata",
+ "launcher",
+ "router"
+]
+resolver = "2"
+
+[workspace.package]
+version = "3.0.0"
+edition = "2021"
+authors = ["Olivier Dehaene"]
+homepage = "https://github.com/huggingface/text-generation-inference"
+
+[workspace.dependencies]
+base64 = "0.22.0"
+tokenizers = { version = "0.20.0", features = ["http"] }
+hf-hub = { version = "0.4.2", features = ["tokio"] }
+metrics = { version = "0.23.0" }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+minijinja = { version = "2.2.0", features = ["json"] }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+
+[profile.release]
+incremental = true
+
+[profile.release-binary]
+inherits = "release"
+debug = 1
+incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
+lto = "fat"
+opt-level = 3
+codegen-units = 1
diff --git a/backends/neuron/Makefile b/backends/neuron/Makefile
new file mode 100644
index 000000000..066749713
--- /dev/null
+++ b/backends/neuron/Makefile
@@ -0,0 +1,35 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+root_dir := "${mkfile_dir}/../.."
+
+.PHONY: image install_server test_server test_integration
+
+VERSION := $(shell gawk 'match($$0, /^version = "(.*)"/, a) {print a[1]}' ${root_dir}/Cargo.toml)
+
+image:
+ docker build --rm -f ${root_dir}/Dockerfile.neuron \
+ --ulimit nofile=100000:100000 \
+ --build-arg VERSION=$(VERSION) \
+ -t text-generation-inference:$(VERSION)-neuron ${root_dir}
+ docker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron
+
+install_server:
+ make -C ${mkfile_dir}/server install VERSION:=${VERSION}
+
+test_server: install_server
+ python -m pip install -r ${mkfile_dir}/tests/requirements.txt
+ python -m pytest -sv ${mkfile_dir}/tests/server
diff --git a/backends/neuron/README.md b/backends/neuron/README.md
new file mode 100644
index 000000000..55722c3bd
--- /dev/null
+++ b/backends/neuron/README.md
@@ -0,0 +1,25 @@
+# Text-generation-inference - Neuron backend for AWS Trainium and Inferentia2
+
+## Description
+
+This is the TGI backend for the AWS Trainium and Inferentia family of chips.
+
+This backend is composed of:
+- the AWS Neuron SDK,
+- the legacy v2 TGI launcher and router,
+- a Neuron-specific inference server for text generation.
+
+## Usage
+
+Please refer to the official [documentation](https://huggingface.co/docs/text-generation-inference/backends/neuron).
+
+## Build your own image
+
+The simplest way to build TGI with the neuron backend is to use the provided `Makefile`:
+
+```shell
+$ make -C backends/neuron image
+```
+
+Alternatively, you can build the image directly from the top directory using a command similar to the one defined
+in the `Makefile` under the `image` target.
diff --git a/backends/neuron/server/.gitignore b/backends/neuron/server/.gitignore
new file mode 100644
index 000000000..378eac25d
--- /dev/null
+++ b/backends/neuron/server/.gitignore
@@ -0,0 +1 @@
+build
diff --git a/backends/neuron/server/Makefile b/backends/neuron/server/Makefile
new file mode 100644
index 000000000..efe34bd0d
--- /dev/null
+++ b/backends/neuron/server/Makefile
@@ -0,0 +1,74 @@
+# Initialize base variables
+SHELL := /bin/bash
+pkg_name := text_generation_server
+BUILDDIR ?= $(CURDIR)/build
+VERSION ?= 0.0.1
+mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
+mkfile_dir := $(dir $(mkfile_path))
+pkg_dir := $(BUILDDIR)/$(pkg_name)
+py_version := $(subst -,.,${VERSION})
+pkg_dist := ${BUILDDIR}/dist/${pkg_name}-$(py_version).tar.gz
+
+clean:
+ rm -rf $(BUILDDIR)/*
+
+${BUILDDIR}:
+ install -d $@
+
+# List static sources to be deployed in the package
+src_dir := $(mkfile_dir)/$(pkg_name)
+sources := $(wildcard $(src_dir)/*.py)
+deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources))
+
+# Static files are just copied
+
+define COPY
+ cp -f $< $@
+endef
+
+# We use a PHONY target to represent the VERSION
+.PHONY: VERSION
+
+VERSION: ${BUILDDIR}
+ # The trick is to compare the value of the variable with the content of a file in the build directory
+ @if [[ `cat ${BUILDDIR}/VERSION 2>&1` != '$(VERSION)' ]]; then echo -n $(VERSION) >${BUILDDIR}/VERSION; fi
+
+# Depending on the PHONY VERSION target makes sure the pyproject.toml is regenerated if the version changes
+$(BUILDDIR)/pyproject.toml: $(mkfile_dir)/pyproject.toml VERSION
+ mkdir -p $(BUILDDIR)
+ $(COPY)
+ sed -i -e 's/version = "VERSION"/version = \"${VERSION}\"/' $@
+
+$(pkg_dir)/%.py: $(src_dir)/%.py
+ mkdir -p $(pkg_dir)
+ $(COPY)
+
+# Generated files are produced by grpcio tools
+
+# If not provided, get local proto files
+ifndef PROTODIR
+PROTODIR := $(mkfile_dir)/../../../proto
+endif
+
+# Three python files are generated for each protobuf
+protobufs := $(PROTODIR)/generate.proto
+pkg_pb_dir := $(pkg_dir)/pb
+generated_sources_base := $(foreach proto, $(protobufs), $(proto:.proto=_pb2.py))
+generated_sources := $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base))
+generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=.pyi))
+generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=_grpc.py))
+
+$(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PROTODIR)/%.proto
+ mkdir -p $(pkg_pb_dir)
+ python -m grpc_tools.protoc -I$(PROTODIR) --python_out=$(pkg_pb_dir) \
+ --grpc_python_out=$(pkg_pb_dir) --mypy_out=$(pkg_pb_dir) $^
+ sed -i -e 's/^\(import.*pb2\)/from . \1/g' $(pkg_pb_dir)/$*_pb2_grpc.py
+
+${pkg_dist}: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources)
+ python -m build $(BUILDDIR)
+
+package: ${pkg_dist}
+
+install: ${pkg_dist}
+ python3 -m pip uninstall -y ${pkg_name}
+ python3 -m pip install ${pkg_dist}
diff --git a/backends/neuron/server/build-requirements.txt b/backends/neuron/server/build-requirements.txt
new file mode 100644
index 000000000..2083bd73f
--- /dev/null
+++ b/backends/neuron/server/build-requirements.txt
@@ -0,0 +1,3 @@
+build
+grpcio-tools==1.53.0
+mypy-protobuf
diff --git a/backends/neuron/server/pyproject.toml b/backends/neuron/server/pyproject.toml
new file mode 100644
index 000000000..6bf4e5eee
--- /dev/null
+++ b/backends/neuron/server/pyproject.toml
@@ -0,0 +1,26 @@
+[build-system]
+requires = ["setuptools>=78.1"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "text-generation-server"
+version = "VERSION"
+authors = [{name="David Corvoysier", email="david@huggingface.co" }]
+description = "TGI compatible inference server for AWS Neuronx platforms"
+dependencies = [
+ 'protobuf > 3.20.1, < 4',
+ 'grpcio == 1.57.0',
+ 'grpcio-status == 1.48.2',
+ 'grpcio-reflection == 1.48.2',
+ 'grpc-interceptor == 0.15.2',
+ 'typer == 0.6.1',
+ 'safetensors',
+ 'loguru == 0.6.0',
+ 'optimum-neuron[neuronx] >= 0.0.28',
+]
+
+[tool.setuptools]
+packages = ["text_generation_server", "text_generation_server.pb"]
+
+[project.scripts]
+text-generation-server = 'text_generation_server.cli:app'
diff --git a/backends/neuron/server/text_generation_server/cli.py b/backends/neuron/server/text_generation_server/cli.py
new file mode 100644
index 000000000..4a9c47345
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/cli.py
@@ -0,0 +1,115 @@
+import sys
+from typing import Optional
+
+import typer
+from loguru import logger
+
+
+app = typer.Typer()
+
+
+@app.command()
+def serve(
+ model_id: str,
+ revision: Optional[str] = None,
+ sharded: bool = False,
+ trust_remote_code: Optional[bool] = None,
+ uds_path: str = "/tmp/text-generation-server",
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ otlp_endpoint: Optional[str] = None,
+ otlp_service_name: str = "text-generation-inference.server",
+ max_input_tokens: Optional[int] = None,
+):
+ """This is the main entry-point for the server CLI.
+
+ Args:
+ model_id (`str`):
+ The *model_id* of a model on the HuggingFace hub or the path to a local model.
+ revision (`Optional[str]`, defaults to `None`):
+ The revision of the model on the HuggingFace hub.
+ sharded (`bool`):
+ Whether the model must be sharded or not. Kept for compatibility with the
+ text-generation-launcher, but must be set to False.
+ trust_remote_code (`Optional[bool]`, defaults to `None`):
+ Kept for compatibility with text-generation-launcher. Ignored.
+ uds_path (`Union[Path, str]`):
+ The local path on which the server will expose its gRPC services.
+ logger_level (`str`):
+ The server logger level. Defaults to *INFO*.
+ json_output (`bool`):
+ Use JSON format for log serialization.
+ otlp_endpoint (`Optional[str]`, defaults to `None`):
+ The Open Telemetry endpoint to use.
+ otlp_service_name (`str`, defaults to `"text-generation-inference.server"`):
+ The name to use when pushing data to the Open Telemetry endpoint.
+ max_input_tokens (`Optional[int]`, defaults to `None`):
+ The maximum number of input tokens each request should contain.
+ """
+ if sharded:
+ raise ValueError("Sharding is not supported.")
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ if trust_remote_code is not None:
+ logger.warning(
+ "'trust_remote_code' argument is not supported and will be ignored."
+ )
+
+ # Import here after the logger is added to log potential import exceptions
+ from .server import serve
+
+ serve(model_id, revision, uds_path)
+
+
+@app.command()
+def download_weights(
+ model_id: str,
+ revision: Optional[str] = None,
+ logger_level: str = "INFO",
+ json_output: bool = False,
+ auto_convert: Optional[bool] = None,
+ extension: Optional[str] = None,
+ trust_remote_code: Optional[bool] = None,
+ merge_lora: Optional[bool] = None,
+):
+ """Download the model weights.
+
+ This command will be called by text-generation-launcher before serving the model.
+ """
+ # Remove default handler
+ logger.remove()
+ logger.add(
+ sys.stdout,
+ format="{message}",
+ filter="text_generation_server",
+ level=logger_level,
+ serialize=json_output,
+ backtrace=True,
+ diagnose=False,
+ )
+
+ if extension is not None:
+ logger.warning("'extension' argument is not supported and will be ignored.")
+ if trust_remote_code is not None:
+ logger.warning(
+ "'trust_remote_code' argument is not supported and will be ignored."
+ )
+ if auto_convert is not None:
+ logger.warning("'auto_convert' argument is not supported and will be ignored.")
+ if merge_lora is not None:
+ logger.warning("'merge_lora' argument is not supported and will be ignored.")
+
+ # Import here after the logger is added to log potential import exceptions
+ from .model import fetch_model
+
+ fetch_model(model_id, revision)
diff --git a/backends/neuron/server/text_generation_server/generator.py b/backends/neuron/server/text_generation_server/generator.py
new file mode 100644
index 000000000..b3887e14c
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/generator.py
@@ -0,0 +1,697 @@
+import copy
+import logging
+import time
+from abc import ABC
+from enum import Enum
+from typing import List, Optional, Tuple
+
+import torch
+from loguru import logger
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers.generation import GenerationConfig
+
+from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron.generation import TokenSelector
+
+from .model import get_export_kwargs_from_env
+from .pb.generate_pb2 import (
+ Batch,
+ CachedBatch,
+ FinishReason,
+ GeneratedText,
+ Generation,
+ InfoResponse,
+ Request,
+ Tokens,
+)
+
+
+# Disable optimum-neuron warnings as it seems to block the server after a while
+optimum_logger = logging.getLogger("optimum.neuron")
+optimum_logger.setLevel("CRITICAL")
+
+
+class Generator(ABC):
+ """An abstract class to represent the workhorse behind TextGenerationService.
+
+ Ideally, it should not rely on protobuf constructs, but in a first step it does.
+ Implementations would typically need a model and a tokenizer to implement the Generator methods.
+ """
+
+ @property
+ def info(self) -> InfoResponse:
+ """This should simply return the expected InfoResponse"""
+ raise NotImplementedError
+
+ def warmup(self, batch: Batch) -> int:
+ """Verify if the hardware can support the target load.
+
+ Args:
+ batch (`Batch`):
+ A batch corresponding to the maximum number of concurrent requests.
+
+ Return:
+ The maximum number of tokens the model supports.
+ """
+ raise NotImplementedError
+
+ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
+ """Prefill is called whenever new requests need to be added.
+
+ When this method returns successfully, a decode method will follow
+ with both the current and newly prefilled batch(es).
+
+ Args:
+ batch (`Batch`):
+ A batch containing the new requests.
+
+ Return:
+ A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
+ """
+ raise NotImplementedError
+
+ def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]:
+ """Decode after a prefill or another decode."""
+ raise NotImplementedError
+
+ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
+ """Remove requests that are not listed from the specified batch"""
+ raise NotImplementedError
+
+ def clear(self):
+ """Remove all requests from the generator"""
+ raise NotImplementedError
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, revision: Optional[str]):
+ """Factory method "a la transformers" """
+ raise NotImplementedError
+
+
+class Slot:
+ """Represents a slot in a static batch"""
+
+ class State(Enum):
+ EMPTY = 0
+ PAUSE = 1
+ READY = 2
+
+ def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
+ self._id = id
+ self._tokenizer = tokenizer
+ self.clear()
+
+ def clear(self):
+ """Clear the slot and mark it as available."""
+ self._state = Slot.State.EMPTY
+ self._batch_id = None
+ self._request_id = None
+ self._inputs = ""
+ self._truncate = 0
+ self._generation_config = None
+ self._tokens = []
+ self._mask = torch.tensor([])
+ self._selector = None
+ self._generated_tokens = 0
+ self._next_text_token_start = 0
+ self._next_text_token_end = 0
+ self._generated_text = ""
+ self._next_text = ""
+
+ @property
+ def id(self) -> int:
+ return self._id
+
+ @property
+ def state(self) -> "Slot.State":
+ return self._state
+
+ @property
+ def batch_id(self) -> int:
+ return self._batch_id
+
+ @property
+ def request_id(self) -> int:
+ return self._request_id
+
+ @property
+ def cached_text(self) -> str:
+ return self._inputs + self._generated_text
+
+ @property
+ def generation_config(self) -> GenerationConfig:
+ return self._generation_config
+
+ @property
+ def generated_tokens(self) -> int:
+ return self._generated_tokens
+
+ def assign(
+ self, batch_id: int, request: Request, generation_config: GenerationConfig
+ ):
+ """Assign a request to a slot.
+
+ Args:
+ batch_id (`int`):
+ The id of the batch to which the request belongs.
+ request (`Request`):
+ The request to be assigned. Contains the inputs and tokens selection parameters.
+ generation_config (`transformers.GenerationConfig`):
+ The base generation config (might be modified by the request generation parameters).
+ """
+ self._state = Slot.State.READY
+ self._batch_id = batch_id
+ self._request_id = request.id
+ self._inputs = request.inputs
+ if request.truncate:
+ self._truncate = request.truncate
+ self._generation_config = copy.deepcopy(generation_config)
+ # Update generation config with request parameters
+ self._generation_config.do_sample = request.parameters.do_sample
+ if self._generation_config.do_sample:
+ if request.parameters.temperature != 0:
+ self._generation_config.temperature = request.parameters.temperature
+ if request.parameters.top_k != 0:
+ self._generation_config.top_k = request.parameters.top_k
+ if request.parameters.top_p != 0:
+ self._generation_config.top_p = request.parameters.top_p
+ if request.parameters.typical_p != 0:
+ self._generation_config.typical_p = request.parameters.typical_p
+ if request.parameters.repetition_penalty != 0:
+ self._generation_config.repetition_penalty = (
+ request.parameters.repetition_penalty
+ )
+ self.seed = request.parameters.seed
+ self._generation_config.max_new_tokens = (
+ request.stopping_parameters.max_new_tokens
+ )
+ self._max_new_tokens = self._generation_config.max_new_tokens
+ stop_strings = request.stopping_parameters.stop_sequences
+ if stop_strings:
+ self._generation_config.stop_strings = stop_strings
+
+ def reset(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: torch.LongTensor,
+ selector: TokenSelector,
+ ):
+ """Reset the slot for the next generation.
+
+ Args:
+            input_ids (`torch.LongTensor`):
+                The new input_ids to use to generate the next token.
+            attention_mask (`torch.LongTensor`):
+                The new attention_mask to use to generate the next token.
+            selector (`optimum.neuron.generation.TokenSelector`):
+                An object implementing the updated token selection logic.
+ """
+ self._tokens = input_ids.clone()
+ self._next_text_token_start = 0
+ self._next_text_token_end = torch.numel(self._tokens)
+ self._next_text = ""
+ self._mask = attention_mask.clone()
+ self._selector = selector
+
+ def pause(self, reset_on_pause: bool):
+ """Mark the current slot as paused for generation.
+
+ Note that the KV cache for this slot will still be filled.
+ """
+ if reset_on_pause:
+ # Drop the last token as it will be added back when resuming the slot
+ self._generated_tokens -= 1
+ # Since generated tokens are now part of the prefill, we need to reevaluate
+ # max_new_tokens for the next generation
+ self._generation_config.max_new_tokens = (
+ self._max_new_tokens - self._generated_tokens
+ )
+ self._state = Slot.State.PAUSE
+
+ def resume(self):
+ """Mark the slot as ready for generation."""
+ self._state = Slot.State.READY
+
+ def _decode_next_tokens(
+ self,
+ ) -> str:
+ """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+ # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
+ # which decide to add a space or not depending on the surrounding ids.
+ new_text = self._tokenizer.decode(
+ self._tokens[self._next_text_token_start :], skip_special_tokens=False
+ )
+ if new_text.endswith("�"):
+ # utf-8 char at the end means it's a potential unfinished byte sequence
+ # from byte fallback tokenization.
+ return ""
+
+ # Compare the generated text with the one using only the tokens producing the last one
+ last_text = self._tokenizer.decode(
+ self._tokens[self._next_text_token_start : self._next_text_token_end],
+ skip_special_tokens=False,
+ )
+ if len(new_text) == len(last_text):
+ # Nothing new was actually generated
+ return ""
+ # Return the decoded text and store its token offsets
+ self._next_text_token_start = self._next_text_token_end
+ self._next_text_token_end = torch.numel(self._tokens)
+ return new_text[len(last_text) :]
+
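+    # Illustrative sketch of _decode_next_tokens (assumed, simplified text values,
+    # not produced by a real tokenizer): if decoding all tokens from
+    # _next_text_token_start yields "Hello world" while decoding only up to
+    # _next_text_token_end yields "Hello", the returned chunk is
+    # new_text[len(last_text):] == " world". Re-decoding from the same start
+    # offset keeps the surrounding ids, so spaces inserted by the tokenizer
+    # cleanup are not lost between streaming chunks.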
+ def append(self, next_token: int) -> str:
+ """Append a new generated token to this slot
+
+        The new token is added to the list of generated tokens, which directly
+        impacts the generated_text and stopped properties.
+
+ The new token is however not added immediately to the slot inputs: it will
+ be added later on when it has effectively been used to produce the next token.
+
+ Args:
+ next_token (`int`):
+ The newly generated token.
+
+ Return:
+ The corresponding decoded text (if any).
+ """
+ self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
+ self._mask = torch.cat([self._mask, torch.LongTensor([1])])
+ self._generated_tokens += 1
+ next_text = self._decode_next_tokens()
+ # Now that a new token has been generated, we can append the previous one to the generated text
+ self._generated_text += self._next_text
+ self._next_text = next_text
+ return next_text
+
+ def select(
+ self, input_ids: torch.LongTensor, logits: torch.Tensor
+ ) -> torch.LongTensor:
+ """Select the next token from the candidate logits.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation (not used in all generation modes).
+ logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The logits corresponding to the generated tokens.
+
+ Return:
+            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
+ """
+ return self._selector.select(input_ids, logits)[0]
+
+ @property
+ def stopped(self) -> bool:
+ # Transformers stopping criteria expects a batch of input ids
+ input_ids = torch.unsqueeze(self._tokens, dim=0)
+ return self._selector.stopping_criteria(input_ids, None)
+
+ @property
+ def generated_text(self) -> str:
+ return self._generated_text + self._next_text
+
+ @property
+ def next_token(self) -> int:
+ return None if len(self._tokens) == 0 else self._tokens[-1]
+
+ @property
+ def attention_mask(self) -> torch.LongTensor:
+ return self._mask
+
+ @property
+ def max_token(self) -> int:
+ return self._generation_config.max_length
+
+ @property
+ def max_new_tokens(self) -> int:
+        # The current value of max_new_tokens: might be different from the target max_new_tokens
+ # if the slot has been paused and resumed.
+ return self._generation_config.max_new_tokens
+
+ @property
+ def truncate(self) -> int:
+ return self._truncate
+
+
+class NeuronGenerator(Generator):
+ """A Generator for Neuron models."""
+
+ def __init__(
+ self,
+ model: NeuronModelForCausalLM,
+ tokenizer: PreTrainedTokenizerBase,
+ ):
+ self.model = model
+ self.rebuild_cache_on_prefill = not self.model.continuous_batching
+ # Specify padding and truncation options for decoder-only architecture
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ tokenizer.padding_side = "left"
+ tokenizer.truncation_side = "left"
+ self.tokenizer = tokenizer
+ self.special_tokens = self.tokenizer.all_special_ids
+ self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
+ self.batch_id = 0
+
+ @property
+ def info(self) -> InfoResponse:
+ """Returns the expected InfoResponse."""
+ dtype = getattr(self.model.config, "torch_dtype", "float32")
+ return InfoResponse(
+ requires_padding=True,
+ dtype=str(dtype),
+ device_type="xla",
+ )
+
+ def warmup(self, batch: Batch) -> int:
+ """Verify if the hardware can support the target load.
+
+ Args:
+ batch (`Batch`):
+ A batch corresponding to the maximum number of concurrent requests.
+
+ Return:
+ The maximum number of tokens the model supports.
+ """
+ # Just check that the warmup request parameters match the model capacity
+ batch_size = self.model.batch_size
+ if len(batch.requests) > batch_size:
+ raise ValueError(
+ f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
+ )
+ self.prefill(batch)
+ self.clear()
+ return self.model.batch_size * self.model.max_length
+
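+    # Illustrative note (hypothetical values): for a model compiled with
+    # batch_size=4 and max_length=4096, warmup() reports 4 * 4096 = 16384
+    # supported tokens, provided the warmup batch does not exceed the compiled
+    # batch size.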
+ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
+ """Prefill new requests.
+
+ Args:
+ batch (`Batch`):
+ A batch containing the new requests.
+
+ Return:
+ A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
+ """
+ slots = {state: [] for state in Slot.State}
+ for slot in self.slots:
+ slots[slot.state].append(slot)
+ active_slots = slots[Slot.State.READY]
+ empty_slots = slots[Slot.State.EMPTY]
+ if len(empty_slots) < len(batch.requests):
+ raise ValueError(
+ f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
+ f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
+ )
+ # Assign each request to an empty slot
+ logger.debug(
+ f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
+ )
+ new_slots = []
+ for request in batch.requests:
+ slot = empty_slots.pop()
+ slot.assign(self.batch_id, request, self.model.generation_config)
+ new_slots.append(slot)
+ logger.debug(
+ f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
+ )
+ if self.rebuild_cache_on_prefill:
+ # We will clear pending slots and prefill all slots
+ prefill_slots = self.slots
+ seq_ids = None
+ else:
+ # We only need to pass inputs for the new requests
+ prefill_slots = new_slots
+ seq_ids = torch.tensor([slot.id for slot in prefill_slots])
+ # Reconstruct the full inputs (without padding) as seen by the model.
+ # This comprises:
+ # - the inputs for new requests,
+ # - only when rebuilding the cache, the inputs and the generated text that has already
+ # been cached (i.e. excluding the last generated token) for unfinished requests.
+ inputs = []
+ max_length = 0
+ for slot in prefill_slots:
+ inputs.append(slot.cached_text)
+ # Apply truncation, making sure we fit into static dimensions
+ if slot.truncate == 0:
+ max_length = self.model.max_length
+ elif slot.truncate > max_length and slot.truncate < self.model.max_length:
+ max_length = slot.truncate
+ # Tokenize with padding and truncation
+ padded_inputs = self.tokenizer(
+ inputs,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ )
+ input_ids = padded_inputs.input_ids
+ attention_mask = padded_inputs.attention_mask
+ # Pause previously active slots during generation
+ next_tokens = []
+ for slot in active_slots:
+ slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
+ if self.rebuild_cache_on_prefill:
+ # The slot will be reset, so we need to store its next token
+ next_tokens.append(slot.next_token)
+ # Each slot must be reset with the padded inputs and masks
+ for i, slot in enumerate(prefill_slots):
+        if slot.state != Slot.State.EMPTY:
+ if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
+ # Apply per-request truncation
+ input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
+ attention_mask[i, : -slot.truncate] = 0
+ slot_input_ids = input_ids[i : i + 1, :]
+            # Padded input ids are also required to set logits processors and stopping criteria
+ selector = TokenSelector.create(
+ slot_input_ids,
+ slot.generation_config,
+ self.model,
+ self.model.max_length,
+ tokenizer=self.tokenizer,
+ seed=slot.seed,
+ )
+ slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
+ slot_attention_mask = attention_mask[i]
+ slot.reset(slot_input_ids, slot_attention_mask, selector)
+ # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
+ # as they have already been generated and sent back in the last decode.
+ model_inputs = self.model.prepare_inputs_for_prefill(
+ input_ids, attention_mask, seq_ids
+ )
+ logits = self.model(**model_inputs)[0]
+ generation, next_batch = self._generate_token(
+ prefill_slots, self.batch_id, logits, input_ids
+ )
+ self.batch_id += 1
+ # Reactivate previously active slots for the next decode
+ for i, slot in enumerate(active_slots):
+ slot.resume()
+ if self.rebuild_cache_on_prefill:
+ # Append back the next token
+ slot.append(next_tokens[i])
+ logger.debug("Model ready for decoding")
+ if next_batch is not None:
+ logger.debug(
+ f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
+ )
+ return generation, next_batch
+
+ def decode(
+ self, batches: List[CachedBatch]
+ ) -> Tuple[List[Generation], CachedBatch]:
+ """Decode the specified prefilled requests.
+
+ Args:
+ batches (`List[CachedBatch]`):
+ A list of previous batches containing the prefilled requests.
+
+ Return:
+ A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
+ """
+ # batches contains a list composed of:
+ # - the batch id returned by the last decode,
+ # - the batch id(s) returned by the last prefill(s)
+ # Batches are always concatenated during prefill, so we can
+ # just carry on with decoding. We adopt the id of the first
+ # batch in the list as our next batch id.
+ next_batch_id = batches[0].id
+ request_ids = []
+ for batch in batches:
+ request_ids += batch.request_ids
+ cleared_request_ids = []
+ for slot in self.slots:
+ if slot.state == slot.State.READY and slot.request_id not in request_ids:
+ cleared_request_ids.append(slot.request_id)
+ slot.clear()
+ if len(cleared_request_ids) > 0:
+ logger.info(
+ f"Clearing slot for requests {cleared_request_ids} as they are not requested."
+ )
+ active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
+ if len(active_slots) < len(request_ids):
+ raise ValueError(
+ "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
+ )
+ if self.model.continuous_batching:
+ decode_slots = active_slots
+ seq_ids = torch.tensor([slot.id for slot in decode_slots])
+ else:
+ decode_slots = self.slots
+ seq_ids = None
+ # Reconstruct input_ids and attention_mask from decode slots
+ n_slots = len(decode_slots)
+ input_ids = torch.full(
+ [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
+ )
+ max_length = 0
+ for slot in decode_slots:
+ max_length = max(max_length, slot.attention_mask.size(-1))
+ attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
+ for i, slot in enumerate(decode_slots):
+ if slot.state != Slot.State.EMPTY:
+ # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
+ input_ids[i, 0] = slot.next_token
+ attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
+ model_inputs = self.model.prepare_inputs_for_decode(
+ input_ids, attention_mask, seq_ids
+ )
+ logits = self.model(**model_inputs)[0]
+ return self._generate_token(decode_slots, next_batch_id, logits, input_ids)
+
+ def _generate_token(
+ self,
+ slots: List[Slot],
+ next_batch_id: int,
+ logits: torch.Tensor,
+ input_ids: torch.LongTensor,
+ ) -> Tuple[List[Generation], CachedBatch]:
+ generations = []
+ active_slots = False
+ for i, slot in enumerate(slots):
+ if slot.state != Slot.State.READY:
+ continue
+ request_id = slot.request_id
+ next_token_logits = logits[i : i + 1, -1, :]
+ slot_input_ids = input_ids[i : i + 1, :]
+ next_token = slot.select(slot_input_ids, next_token_logits)
+ next_token_text = slot.append(next_token)
+ generated_text = None
+ finish_reason = None
+ if next_token == self.tokenizer.eos_token_id:
+ finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
+ elif slot.stopped:
+ if slot.generated_tokens == slot.max_new_tokens:
+ finish_reason = FinishReason.FINISH_REASON_LENGTH
+ else:
+ finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
+ if finish_reason is not None:
+ # We must include the generated text for each finished sequence in the response
+ generated_text = GeneratedText(
+ text=slot.generated_text,
+ generated_tokens=slot.generated_tokens,
+ finish_reason=finish_reason,
+ )
+ logger.debug(
+ f"Decode complete for request {request_id} with {slot.generated_tokens} tokens"
+ )
+ # mark the slot as available
+ slot.clear()
+ else:
+ active_slots = True
+ generations.append(
+ Generation(
+ request_id=request_id,
+ prefill_tokens=None,
+ tokens=Tokens(
+ ids=[next_token],
+ logprobs=[0],
+ texts=[next_token_text],
+ is_special=[next_token in self.special_tokens],
+ ),
+ generated_text=generated_text,
+ )
+ )
+ batch = None
+ if active_slots:
+ # Whatever initial batch these requests came from, we always return all pending requests in a single batch
+ request_ids = [
+ slot.request_id for slot in self.slots if slot.state == Slot.State.READY
+ ]
+ batch = self._cached_batch(next_batch_id, request_ids)
+ else:
+ logger.debug("No more pending requests")
+ return generations, batch
+
+ def _cached_batch(self, batch_id: int, request_ids: List):
+ size = len(request_ids)
+ max_tokens = size * self.model.max_length
+ return CachedBatch(
+ id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
+ )
+
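+    # Note (assumed value): with max_length=4096, a cached batch holding 2
+    # pending requests reports max_tokens = 2 * 4096 = 8192; because of static
+    # batching this value is recomputed here rather than taken from the client.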
+ def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
+ """Remove requests that are not listed from the specified batch
+
+ Args:
+ batch_id (`int`):
+ The id of a cached batch.
+            keep_request_ids (`List[int]`):
+ The list of requests that must be kept.
+
+ Return:
+ A `CachedBatch` containing the pending requests.
+ """
+ keep_slot_ids = [
+ slot.id for slot in self.slots if slot.request_id in keep_request_ids
+ ]
+ self._clear(keep_slot_ids)
+ return self._cached_batch(batch_id, keep_request_ids)
+
+ def clear(self, batch_id: Optional[int] = None):
+ """Remove a subset or all requests from the generator"""
+ keep_ids = []
+ if batch_id is not None:
+ keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
+ return self._clear(keep_ids)
+
+ def _clear(self, keep_slot_ids: List):
+ for slot in self.slots:
+ if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
+ logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
+ slot.clear()
+
+ @classmethod
+ def from_pretrained(cls, model_id: str, revision: str = None):
+ """Instantiate a NeuronGenerator.
+
+ Args:
+ model_id (`str`):
+ A hub model id or the path to a local model. This path must also contain a Tokenizer.
+ revision (`Optional[str]`, defaults to `None`):
+ The revision of the model on the HuggingFace hub.
+
+ Returns:
+ A NeuronGenerator.
+ """
+ config = AutoConfig.from_pretrained(model_id)
+ neuron_config = getattr(config, "neuron", None)
+ start = time.time()
+ if neuron_config is None:
+ export_kwargs = get_export_kwargs_from_env()
+ logger.info(f"Exporting model to neuron with config: {export_kwargs}.")
+ model = NeuronModelForCausalLM.from_pretrained(
+ model_id,
+ revision=revision,
+ low_cpu_mem_usage=True,
+ export=True,
+ **export_kwargs,
+ )
+ else:
+ logger.info(
+ "Loading model on neuron devices (this can take a few minutes)."
+ )
+ model = NeuronModelForCausalLM.from_pretrained(
+ model_id, low_cpu_mem_usage=True, revision=revision
+ )
+ end = time.time()
+ logger.info(f"Model successfully loaded in {end - start:.2f} s.")
+ tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
+ return cls(model, tokenizer)
diff --git a/backends/neuron/server/text_generation_server/interceptor.py b/backends/neuron/server/text_generation_server/interceptor.py
new file mode 100644
index 000000000..301cafd87
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/interceptor.py
@@ -0,0 +1,29 @@
+from typing import Any, Callable
+
+import grpc
+from google.rpc import code_pb2, status_pb2
+from grpc_interceptor.server import AsyncServerInterceptor
+from grpc_status import rpc_status
+from loguru import logger
+
+
+class ExceptionInterceptor(AsyncServerInterceptor):
+ async def intercept(
+ self,
+ method: Callable,
+ request_or_iterator: Any,
+ context: grpc.ServicerContext,
+ method_name: str,
+ ) -> Any:
+ try:
+ response = method(request_or_iterator, context)
+ return await response
+ except Exception as err:
+ method_name = method_name.split("/")[-1]
+ logger.exception(f"Method {method_name} encountered an error.")
+
+ await context.abort_with_status(
+ rpc_status.to_status(
+ status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
+ )
+ )
diff --git a/backends/neuron/server/text_generation_server/model.py b/backends/neuron/server/text_generation_server/model.py
new file mode 100644
index 000000000..2151a9218
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/model.py
@@ -0,0 +1,128 @@
+import os
+import shutil
+import time
+from typing import Optional
+
+from huggingface_hub import snapshot_download
+from huggingface_hub.constants import HF_HUB_CACHE
+from loguru import logger
+from transformers import AutoConfig
+
+from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron.utils import get_hub_cached_entries
+
+
+def get_export_kwargs_from_env():
+ batch_size = os.environ.get("MAX_BATCH_SIZE", None)
+ if batch_size is not None:
+ batch_size = int(batch_size)
+ sequence_length = os.environ.get("MAX_TOTAL_TOKENS", None)
+ if sequence_length is not None:
+ sequence_length = int(sequence_length)
+ num_cores = os.environ.get("HF_NUM_CORES", None)
+ if num_cores is not None:
+ num_cores = int(num_cores)
+ auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None)
+ return {
+ "task": "text-generation",
+ "batch_size": batch_size,
+ "sequence_length": sequence_length,
+ "num_cores": num_cores,
+ "auto_cast_type": auto_cast_type,
+ }
+
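+# Illustrative sketch (assumed environment values): with MAX_BATCH_SIZE=4,
+# MAX_TOTAL_TOKENS=4096, HF_NUM_CORES=2 and HF_AUTO_CAST_TYPE=fp16, the helper
+# above returns {"task": "text-generation", "batch_size": 4,
+# "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"};
+# unset variables are simply forwarded as None.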
+
+def is_cached(model_id, neuron_config):
+ # Look for cached entries for the specified model
+ in_cache = False
+ entries = get_hub_cached_entries(model_id, "inference")
+ # Look for compatible entries
+ for entry in entries:
+ compatible = True
+ for key, value in neuron_config.items():
+ # Only weights can be different
+ if key in ["checkpoint_id", "checkpoint_revision"]:
+ continue
+ if entry[key] != value:
+ compatible = False
+ if compatible:
+ in_cache = True
+ break
+ return in_cache
+
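+# Illustrative note (hypothetical cache entry): an entry with
+# {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "fp16"}
+# matches a neuron_config with the same values even if checkpoint_id and
+# checkpoint_revision differ, since only the weights may vary between cached
+# compilation artifacts.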
+
+def log_cache_size():
+ path = HF_HUB_CACHE
+ if os.path.exists(path):
+ usage = shutil.disk_usage(path)
+ gb = 2**30
+ logger.info(
+ f"Cache disk [{path}]: total = {usage.total / gb:.2f} G, free = {usage.free / gb:.2f} G"
+ )
+ else:
+ raise ValueError(f"The cache directory ({path}) does not exist.")
+
+
+def fetch_model(
+ model_id: str,
+ revision: Optional[str] = None,
+) -> str:
+ """Fetch a neuron model.
+
+ Args:
+ model_id (`str`):
+ The *model_id* of a model on the HuggingFace hub or the path to a local model.
+ revision (`Optional[str]`, defaults to `None`):
+ The revision of the model on the HuggingFace hub.
+
+ Returns:
+ A string corresponding to the model_id or path.
+ """
+ if not os.path.isdir("/sys/class/neuron_device/"):
+ raise SystemError("No neuron cores detected on the host.")
+ if os.path.isdir(model_id) and revision is not None:
+ logger.warning(
+ "Revision {} ignored for local model at {}".format(revision, model_id)
+ )
+ revision = None
+ # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
+ # Note that the model may already be present in the cache.
+ config = AutoConfig.from_pretrained(model_id, revision=revision)
+ neuron_config = getattr(config, "neuron", None)
+ if neuron_config is not None:
+ if os.path.isdir(model_id):
+ return model_id
+ # Prefetch the neuron model from the Hub
+ logger.info(
+ f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}"
+ )
+ log_cache_size()
+ return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+ # Model needs to be exported: look for compatible cached entries on the hub
+ export_kwargs = get_export_kwargs_from_env()
+ export_config = NeuronModelForCausalLM.get_export_config(
+ model_id, config, revision=revision, **export_kwargs
+ )
+ neuron_config = export_config.neuron
+ if not is_cached(model_id, neuron_config):
+ hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
+ neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
+ error_msg = (
+ f"No cached version found for {model_id} with {neuron_config}."
+ f"You can start a discussion to request it on {hub_cache_url}"
+ f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
+ )
+ raise ValueError(error_msg)
+ logger.warning(
+ f"{model_id} is not a neuron model: it will be exported using cached artifacts."
+ )
+ if os.path.isdir(model_id):
+ return model_id
+ # Prefetch weights, tokenizer and generation config so that they are in cache
+ log_cache_size()
+ start = time.time()
+ snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
+ end = time.time()
+ logger.info(f"Model weights fetched in {end - start:.2f} s.")
+ log_cache_size()
+ return model_id
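+# Illustrative usage sketch (hypothetical model id): fetch_model("my-org/my-neuron-model")
+# requires neuron devices on the host and returns either a local path (local or
+# pre-exported hub models, downloaded without the *.bin checkpoint files) or the
+# original model id (hub models that still need to be exported, after prefetching
+# their weights into the cache).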
diff --git a/backends/neuron/server/text_generation_server/server.py b/backends/neuron/server/text_generation_server/server.py
new file mode 100644
index 000000000..8eb2592d6
--- /dev/null
+++ b/backends/neuron/server/text_generation_server/server.py
@@ -0,0 +1,89 @@
+import asyncio
+from pathlib import Path
+from typing import List
+
+from grpc import aio
+from grpc_reflection.v1alpha import reflection
+from loguru import logger
+
+from .generator import Generator, NeuronGenerator
+from .interceptor import ExceptionInterceptor
+from .pb import generate_pb2, generate_pb2_grpc
+
+
+class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
+ def __init__(self, generator: Generator, server_urls: List[str]):
+ self.generator = generator
+ self.server_urls = server_urls
+
+ async def Info(self, request, context):
+ return self.generator.info
+
+ async def Health(self, request, context):
+ return generate_pb2.HealthResponse()
+
+ async def ServiceDiscovery(self, request, context):
+ return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
+
+ async def ClearCache(self, request, context):
+ if request.HasField("id"):
+ self.generator.clear(request.id)
+ else:
+ self.generator.clear()
+ return generate_pb2.ClearCacheResponse()
+
+ async def FilterBatch(self, request, context):
+ filtered_batch = self.generator.filter(request.batch_id, request.request_ids)
+ return generate_pb2.FilterBatchResponse(batch=filtered_batch)
+
+ async def Warmup(self, request, context):
+ max_tokens = self.generator.warmup(request.batch)
+ return generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens)
+
+ async def Prefill(self, request, context):
+ generations, batch = self.generator.prefill(request.batch)
+ return generate_pb2.PrefillResponse(generations=generations, batch=batch)
+
+ async def Decode(self, request, context):
+ generations, batch = self.generator.decode(request.batches)
+ return generate_pb2.DecodeResponse(generations=generations, batch=batch)
+
+
+def serve(
+ model_id: str,
+ revision: str,
+ uds_path: Path,
+):
+ async def serve_inner(model_id: str, revision: str):
+ unix_socket_template = "unix://{}-{}"
+ local_url = unix_socket_template.format(uds_path, 0)
+ server_urls = [local_url]
+
+ try:
+ generator = NeuronGenerator.from_pretrained(model_id, revision)
+ except Exception:
+ logger.exception("Error when initializing model")
+ raise
+
+ server = aio.server(interceptors=[ExceptionInterceptor()])
+ generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
+ TextGenerationService(generator, server_urls), server
+ )
+ SERVICE_NAMES = (
+ generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
+ reflection.SERVICE_NAME,
+ )
+ reflection.enable_server_reflection(SERVICE_NAMES, server)
+ server.add_insecure_port(local_url)
+
+ await server.start()
+
+ logger.info("Server started at {}".format(local_url))
+
+ try:
+ await server.wait_for_termination()
+ except KeyboardInterrupt:
+ logger.info("Signal received. Shutting down")
+ await server.stop(0)
+
+ asyncio.run(serve_inner(model_id, revision))
diff --git a/backends/neuron/tests/conftest.py b/backends/neuron/tests/conftest.py
new file mode 100644
index 000000000..1dd20c8c6
--- /dev/null
+++ b/backends/neuron/tests/conftest.py
@@ -0,0 +1 @@
+pytest_plugins = ["fixtures.model"]
diff --git a/backends/neuron/tests/fixtures/model.py b/backends/neuron/tests/fixtures/model.py
new file mode 100644
index 000000000..4b6a1375d
--- /dev/null
+++ b/backends/neuron/tests/fixtures/model.py
@@ -0,0 +1,164 @@
+import copy
+import logging
+import subprocess
+import sys
+from tempfile import TemporaryDirectory
+
+import huggingface_hub
+import pytest
+from transformers import AutoTokenizer
+
+from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron.utils import synchronize_hub_cache
+from optimum.neuron.version import __sdk_version__ as sdk_version
+from optimum.neuron.version import __version__ as version
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s",
+ stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)
+
+OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
+
+# All model configurations below will be added to the neuron_model_config fixture
+MODEL_CONFIGURATIONS = {
+ "gpt2": {
+ "model_id": "gpt2",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 1024,
+ "num_cores": 2,
+ "auto_cast_type": "fp16",
+ },
+ },
+ "llama": {
+ "model_id": "NousResearch/Hermes-2-Theta-Llama-3-8B",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 2048,
+ "num_cores": 2,
+ "auto_cast_type": "fp16",
+ },
+ },
+ "mistral": {
+ "model_id": "optimum/mistral-1.1b-testing",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+ "qwen2": {
+ "model_id": "Qwen/Qwen2.5-0.5B",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "fp16",
+ },
+ },
+ "granite": {
+ "model_id": "ibm-granite/granite-3.1-2b-instruct",
+ "export_kwargs": {
+ "batch_size": 4,
+ "sequence_length": 4096,
+ "num_cores": 2,
+ "auto_cast_type": "bf16",
+ },
+ },
+}
+
+
+def get_hub_neuron_model_id(config_name: str):
+ return (
+ f"optimum-internal-testing/neuron-testing-{version}-{sdk_version}-{config_name}"
+ )
+
+
+def export_model(model_id, export_kwargs, neuron_model_path):
+ export_command = [
+ "optimum-cli",
+ "export",
+ "neuron",
+ "-m",
+ model_id,
+ "--task",
+ "text-generation",
+ ]
+ for kwarg, value in export_kwargs.items():
+ export_command.append(f"--{kwarg}")
+ export_command.append(str(value))
+ export_command.append(neuron_model_path)
+ logger.info(f"Exporting {model_id} with {export_kwargs}")
+ try:
+ subprocess.run(export_command, check=True)
+ except subprocess.CalledProcessError as e:
+ raise ValueError(f"Failed to export model: {e}")
+
+
+@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
+def neuron_model_config(request):
+ """Expose a pre-trained neuron model
+
+ The fixture first makes sure the following model artifacts are present on the hub:
+    - exported neuron model under optimum-internal-testing/neuron-testing-<version>-<sdk_version>-<config_name>,
+ - cached artifacts under optimum-internal-testing/neuron-testing-cache.
+ If not, it will export the model and push it to the hub.
+
+    It then fetches the model locally and returns a dictionary containing:
+ - a configuration name,
+ - the original model id,
+ - the export parameters,
+ - the neuron model id,
+ - the neuron model local path.
+
+ For each exposed model, the local directory is maintained for the duration of the
+ test session and cleaned up afterwards.
+    The hub model artifacts are never cleaned up and persist across sessions.
+ They must be cleaned up manually when the optimum-neuron version changes.
+
+ """
+ config_name = request.param
+ model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
+ model_id = model_config["model_id"]
+ export_kwargs = model_config["export_kwargs"]
+ neuron_model_id = get_hub_neuron_model_id(config_name)
+ with TemporaryDirectory() as neuron_model_path:
+ hub = huggingface_hub.HfApi()
+ if hub.repo_exists(neuron_model_id):
+ logger.info(f"Fetching {neuron_model_id} from the HuggingFace hub")
+ hub.snapshot_download(neuron_model_id, local_dir=neuron_model_path)
+ else:
+ export_model(model_id, export_kwargs, neuron_model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.save_pretrained(neuron_model_path)
+ del tokenizer
+ # Create the test model on the hub
+ hub.create_repo(neuron_model_id, private=True)
+ hub.upload_folder(
+ folder_path=neuron_model_path,
+ repo_id=neuron_model_id,
+ ignore_patterns=[NeuronModelForCausalLM.CHECKPOINT_DIR + "/*"],
+ )
+ # Make sure it is cached
+ synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
+ # Add dynamic parameters to the model configuration
+ model_config["neuron_model_path"] = neuron_model_path
+ model_config["neuron_model_id"] = neuron_model_id
+ # Also add model configuration name to allow tests to adapt their expectations
+ model_config["name"] = config_name
+ # Yield instead of returning to keep a reference to the temporary directory.
+ # It will go out of scope and be released only once all tests needing the fixture
+ # have been completed.
+ logger.info(f"{config_name} ready for testing ...")
+ yield model_config
+ logger.info(f"Done with {config_name}")
+
+
+@pytest.fixture(scope="module")
+def neuron_model_path(neuron_model_config):
+ yield neuron_model_config["neuron_model_path"]
diff --git a/backends/neuron/tests/prune_test_models.py b/backends/neuron/tests/prune_test_models.py
new file mode 100644
index 000000000..448962fb6
--- /dev/null
+++ b/backends/neuron/tests/prune_test_models.py
@@ -0,0 +1,23 @@
+from argparse import ArgumentParser
+from huggingface_hub import HfApi
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument("--yes", action="store_true", default=False)
+ args = parser.parse_args()
+ api = HfApi()
+ models = api.list_models(search="optimum-internal-testing/neuron-tgi-testing")
+ for model in models:
+ if args.yes:
+ delete = True
+ else:
+ answer = input(f"Do you want to delete {model.id} [y/N] ?")
+ delete = answer == "y"
+ if delete:
+ api.delete_repo(model.id)
+ print(f"Deleted {model.id}.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/backends/neuron/tests/pytest.ini b/backends/neuron/tests/pytest.ini
new file mode 100644
index 000000000..2f4c80e30
--- /dev/null
+++ b/backends/neuron/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
diff --git a/backends/neuron/tests/requirements.txt b/backends/neuron/tests/requirements.txt
new file mode 100644
index 000000000..ef3c8543e
--- /dev/null
+++ b/backends/neuron/tests/requirements.txt
@@ -0,0 +1,19 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+text-generation >= 0.6.0
+pytest >= 7.4.0
+pytest-asyncio >= 0.21.1
+requests < 2.32.0
+docker >= 6.1.3
+Levenshtein
diff --git a/backends/neuron/tests/server/helpers.py b/backends/neuron/tests/server/helpers.py
new file mode 100644
index 000000000..f0f81d06d
--- /dev/null
+++ b/backends/neuron/tests/server/helpers.py
@@ -0,0 +1,173 @@
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import (
+ Batch,
+ NextTokenChooserParameters,
+ Request,
+ StoppingCriteriaParameters,
+)
+
+
+def create_request(
+ id: int,
+ inputs: str,
+ truncate: int = 0,
+ max_new_tokens: int = 20,
+ do_sample: bool = False,
+ top_k: int = 50,
+ top_p: float = 0.9,
+ temperature: float = 1.0,
+ seed: int = 42,
+ repetition_penalty: float = 1.0,
+):
+ parameters = NextTokenChooserParameters(
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ do_sample=do_sample,
+ seed=seed,
+ repetition_penalty=repetition_penalty,
+ )
+ stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
+ return Request(
+ id=id,
+ inputs=inputs,
+ truncate=truncate,
+ parameters=parameters,
+ stopping_parameters=stopping_parameters,
+ )
+
+
+def check_prefill(
+ input_text,
+ expected_token_id,
+ expected_token_text,
+ do_sample,
+ batch_size,
+ model_path,
+):
+ """Verify that a prefill for a single request generates the expected output."""
+ generator = NeuronGenerator.from_pretrained(model_path)
+ assert generator.model.batch_size >= batch_size
+ requests = []
+ max_new_tokens = 20
+ for i in range(batch_size):
+ requests.append(
+ create_request(
+                id=i,
+ inputs=input_text,
+ do_sample=do_sample,
+ max_new_tokens=max_new_tokens,
+ )
+ )
+    # Let's be pessimistic when estimating max_tokens
+ max_length = generator.model.max_length
+ batch = Batch(
+ id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
+ )
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == batch_size
+ # Whatever was passed as max_tokens, the server will correct it
+ # because of static batching
+ assert next_batch.max_tokens == batch_size * max_length
+ assert len(generations) == batch_size
+ for g in generations:
+ tokens = g.tokens
+ assert tokens.ids == [expected_token_id]
+ assert tokens.texts == [expected_token_text]
+
+
+def check_decode_single(
+ input_text, max_new_tokens, generated_text, do_sample, model_path
+):
+ """Verify that a decoding for a single request generates the expected output."""
+ generator = NeuronGenerator.from_pretrained(model_path)
+ request = create_request(
+ id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
+ )
+ max_length = generator.model.max_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ # We already generated one token: call decode max_new_tokens - 1 times
+ for _ in range(max_new_tokens - 1):
+ assert next_batch.size == 1
+ assert next_batch.max_tokens == max_length
+ assert len(generations) == 1
+ assert len(generations[0].tokens.ids) == 1
+ generations, next_batch = generator.decode([next_batch])
+ assert next_batch is None
+ assert len(generations) == 1
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert output.finish_reason == 0
+ assert output.text == generated_text
+
+
+def check_decode_multiple(model_path):
+ """Verify that two requests added to the batch at different generation steps
+ generate the same outputs (continuous batching).
+ """
+ generator = NeuronGenerator.from_pretrained(model_path)
+ assert generator.model.batch_size > 1
+ input_text = "Once upon a time"
+ max_new_tokens = 20
+ # Prefill a single request, remembering the generated token
+ tokens = {0: [], 1: []}
+ request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
+ max_length = generator.model.max_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == 1
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == 1
+ # Decode a few tokens
+ gen_tokens = 4
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert next_batch.size == 1
+ # Add a second request
+ request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
+ batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch_1 = generator.prefill(batch)
+ assert next_batch_1.size == 1
+ # We should have generated only a single token
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert len(tokens[1]) == 1
+ # Decode more tokens until we reach the maximum for the first request
+ batches = [next_batch, next_batch_1]
+ for _ in range(max_new_tokens - gen_tokens):
+ generations, next_batch = generator.decode(batches)
+ for g in generations:
+ tokens[g.request_id].append(g.tokens.ids[0])
+ batches = [next_batch]
+ # Verify we now only have one pending request
+ assert next_batch.size == 1
+ assert len(tokens[0]) == max_new_tokens
+ assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
+ # Verify we have the output for the first request
+ for g in generations:
+ if g.request_id == 0:
+ output = g.generated_text
+ assert output.text != ""
+ assert output.generated_tokens == max_new_tokens
+ generated_text = output.text
+ # Continue decoding until the end of the second request
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert next_batch is None
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert tokens[0] == tokens[1]
+ assert output.text == generated_text
diff --git a/backends/neuron/tests/server/test_continuous_batching.py b/backends/neuron/tests/server/test_continuous_batching.py
new file mode 100644
index 000000000..48bb70cc8
--- /dev/null
+++ b/backends/neuron/tests/server/test_continuous_batching.py
@@ -0,0 +1,74 @@
+from helpers import create_request
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_continuous_batching_two_requests(neuron_model_config):
+ """Verify that two requests added to the batch at different generation steps
+ generate the same outputs (continuous batching).
+ """
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ assert generator.model.batch_size > 1
+ input_text = "Once upon a time"
+ max_new_tokens = 20
+ # Prefill a single request, remembering the generated token
+ tokens = {0: [], 1: []}
+ request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
+ max_length = generator.model.max_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == 1
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == 1
+ # Decode a few tokens
+ gen_tokens = 4
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert next_batch.size == 1
+ # Add a second request
+ request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
+ batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch_1 = generator.prefill(batch)
+ assert next_batch_1.size == 1
+ # We should have generated only a single token
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert len(tokens[0]) == gen_tokens
+ assert len(tokens[1]) == 1
+ # Decode more tokens until we reach the maximum for the first request
+ batches = [next_batch, next_batch_1]
+ for _ in range(max_new_tokens - gen_tokens):
+ generations, next_batch = generator.decode(batches)
+ for g in generations:
+ tokens[g.request_id].append(g.tokens.ids[0])
+ batches = [next_batch]
+ # Verify we now only have one pending request
+ assert next_batch.size == 1
+ assert len(tokens[0]) == max_new_tokens
+ assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
+ # Verify we have the output for the first request
+ for g in generations:
+ if g.request_id == 0:
+ output = g.generated_text
+ assert output.text != ""
+ assert output.generated_tokens == max_new_tokens
+ generated_text = output.text
+ # Continue decoding until the end of the second request
+ for _ in range(gen_tokens - 1):
+ generations, next_batch = generator.decode([next_batch])
+ assert len(generations) == 1
+ g = generations[0]
+ tokens[g.request_id].append(g.tokens.ids[0])
+ assert next_batch is None
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert tokens[0] == tokens[1]
+ assert output.text == generated_text
diff --git a/backends/neuron/tests/server/test_decode.py b/backends/neuron/tests/server/test_decode.py
new file mode 100644
index 000000000..9db5e3abb
--- /dev/null
+++ b/backends/neuron/tests/server/test_decode.py
@@ -0,0 +1,59 @@
+from helpers import create_request
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_decode(neuron_model_config):
+ """Verify that a decoding for a single request generates the expected output."""
+ config_name = neuron_model_config["name"]
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ for do_sample in [True, False]:
+ mode = "sample" if do_sample else "greedy"
+ print(f"{config_name}[{mode}]")
+ _test_decode(config_name, generator, do_sample)
+ generator.clear()
+
+
+def _test_decode(config_name, generator, do_sample):
+ input_text = (
+ "It was a bright cold day in April, and the clocks were striking thirteen."
+ )
+ max_new_tokens = 20
+ request = create_request(
+ id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
+ )
+ max_length = generator.model.max_length
+ batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+ generations, next_batch = generator.prefill(batch)
+ # We already generated one token: call decode max_new_tokens - 1 times
+ for _ in range(max_new_tokens - 1):
+ assert next_batch.size == 1
+ assert next_batch.max_tokens == max_length
+ assert len(generations) == 1
+ assert len(generations[0].tokens.ids) == 1
+ generations, next_batch = generator.decode([next_batch])
+ assert next_batch is None
+ assert len(generations) == 1
+ output = generations[0].generated_text
+ assert output.generated_tokens == max_new_tokens
+ assert output.finish_reason == 0
+ if do_sample:
+ expected_text = {
+ "gpt2": " The sun was set",
+ "llama": "George Orwell, 1984",
+ "mistral": "The sky was",
+ "qwen2": " A young woman with",
+ "granite": "1984, George Orwell",
+ }[config_name]
+ assert expected_text in output.text
+ else:
+ print(output.text)
+ expected_text = {
+ "gpt2": '\n\n"I\'m going to go to bed," I said.\n\n"I\'m going',
+ "llama": " George Orwell’s classic dystopian novel, 1984, begins with this ominous sentence. The story",
+ "mistral": "\nThe clocks were striking thirteen.\nThe clocks were striking thirteen.",
+ "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
+ "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
+ }[config_name]
+ assert output.text == expected_text
diff --git a/backends/neuron/tests/server/test_generator_slot.py b/backends/neuron/tests/server/test_generator_slot.py
new file mode 100644
index 000000000..0c03e9d1e
--- /dev/null
+++ b/backends/neuron/tests/server/test_generator_slot.py
@@ -0,0 +1,66 @@
+import pytest
+import torch
+from text_generation_server.generator import Slot
+from text_generation_server.pb.generate_pb2 import Request
+from transformers import AutoTokenizer, GenerationConfig
+
+
+TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "gpt2"]
+
+
+@pytest.fixture(params=TOKENIZERS)
+def tokenizer(request):
+ t = AutoTokenizer.from_pretrained(request.param)
+ t.padding_side = "left"
+ t.pad_token_id = t.eos_token_id
+ return t
+
+
+@pytest.mark.parametrize(
+ "input_text, generated_text",
+ [
+ [
+ "It was a bright cold day in April, and the clocks were striking thirteen.",
+ " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,"
+ " slipped quickly through the glass doors of Victory Mansions, though not quickly enough"
+ " to prevent a swirl of gritty dust from entering along with him.",
+ ],
+ ["This sentence is written in chinese:", "我很感谢你的热情"],
+ ["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"],
+ ],
+ ids=["spaces", "chinese-utf8", "emojis"],
+)
+def test_decode_streaming(tokenizer, input_text, generated_text):
+ slot = Slot(0, tokenizer)
+ request = Request(id=0, inputs=input_text)
+ slot.assign(0, request, GenerationConfig())
+ assert slot.cached_text == input_text
+
+ inputs = tokenizer(
+ input_text,
+ padding="max_length",
+ max_length=len(input_text) + 1,
+ return_tensors="pt",
+ )
+ input_ids = inputs["input_ids"][0]
+ attention_mask = inputs["attention_mask"][0]
+ generated_tokens = tokenizer(generated_text, add_special_tokens=False)["input_ids"]
+
+ # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)
+ all_input_ids = torch.cat([input_ids, torch.tensor(generated_tokens)])
+ full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True)
+ regenerated_text = full_text[len(input_text) :]
+
+ # Initialize the slot with the inputs
+ slot.reset(input_ids, attention_mask, selector=None)
+
+ assert slot.generated_tokens == 0
+
+ # Simulate an iterative generation (i.e. don't call select and use known tokens instead)
+ decoded_text = ""
+ for i in range(len(generated_tokens)):
+ text = slot.append(generated_tokens[i])
+ assert slot.generated_tokens == i + 1
+ decoded_text += text
+
+ assert decoded_text == regenerated_text
diff --git a/backends/neuron/tests/server/test_info.py b/backends/neuron/tests/server/test_info.py
new file mode 100644
index 000000000..5913acec4
--- /dev/null
+++ b/backends/neuron/tests/server/test_info.py
@@ -0,0 +1,10 @@
+from text_generation_server.generator import NeuronGenerator
+
+
+def test_info(neuron_model_path):
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ info = generator.info
+ assert info.requires_padding is True
+ assert info.device_type == "xla"
+ assert info.window_size == 0
+ assert info.speculate == 0
diff --git a/backends/neuron/tests/server/test_prefill.py b/backends/neuron/tests/server/test_prefill.py
new file mode 100644
index 000000000..c0155b1a1
--- /dev/null
+++ b/backends/neuron/tests/server/test_prefill.py
@@ -0,0 +1,102 @@
+from helpers import create_request
+from text_generation_server.generator import NeuronGenerator
+from text_generation_server.pb.generate_pb2 import Batch
+
+
+def test_prefill(neuron_model_config):
+ """Verify that a prefill for a single request generates the expected output."""
+ config_name = neuron_model_config["name"]
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ max_batch_size = 4
+ assert generator.model.batch_size >= max_batch_size
+ for num_requests in [1, max_batch_size]:
+ for do_sample in [True, False]:
+ mode = "sample" if do_sample else "greedy"
+ print(f"[{mode}]: {num_requests} requests")
+ _test_prefill(config_name, generator, num_requests, do_sample)
+ generator.clear()
+
+
+def _test_prefill(config_name, generator, batch_size, do_sample):
+ requests = []
+ max_new_tokens = 20
+ input_text = (
+ "It was a bright cold day in April, and the clocks were striking thirteen."
+ )
+ for i in range(batch_size):
+ requests.append(
+ create_request(
+ id=i,
+ inputs=input_text,
+ do_sample=do_sample,
+ max_new_tokens=max_new_tokens,
+ )
+ )
+ # Let's be pessimistic when estimating max_tokens
+ max_length = generator.model.max_length
+ batch = Batch(
+ id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
+ )
+ generations, next_batch = generator.prefill(batch)
+ assert next_batch.size == batch_size
+ # Whatever was passed as max_tokens, the server will correct it
+ # because of static batching
+ assert next_batch.max_tokens == batch_size * max_length
+ assert len(generations) == batch_size
+ if do_sample:
+ expectations = {
+ "gpt2": [383, " The"],
+ "llama": [10058, " George"],
+ "mistral": [450, " The"],
+ "qwen2": [362, " A"],
+ "granite": [308, " ("],
+ }[config_name]
+ else:
+ expectations = {
+ "gpt2": [198, "\n"],
+ "llama": [10058, " George"],
+ "mistral": [13, "\n"],
+ "qwen2": [358, " I"],
+ "granite": [203, "\n"],
+ }[config_name]
+ for g in generations:
+ tokens = g.tokens
+ assert tokens.ids[0] == expectations[0]
+ assert tokens.texts[0] == expectations[1]
+
+
+def test_prefill_truncate(neuron_model_config):
+ config_name = neuron_model_config["name"]
+ neuron_model_path = neuron_model_config["neuron_model_path"]
+ generator = NeuronGenerator.from_pretrained(neuron_model_path)
+ batch_size = generator.model.batch_size
+ # We apply truncation to all requests but the first one
+ truncate = [
+ None,
+ ] + [i * 3 for i in range(1, batch_size)]
+ input_text = (
+ "Two gin-scented tears trickled down the sides of his nose."
+ " But it was all right, everything was all right, the struggle was finished."
+ " He had won the victory over himself. He loved Big Brother."
+ )
+ requests = []
+ for i in range(batch_size):
+ requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
+ max_length = generator.model.max_length
+ batch = Batch(
+ id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
+ )
+ generations, _ = generator.prefill(batch)
+ # Even if the input text is identical for all requests, the first generated token might
+ # be different because of the truncation
+ expectations = {
+ "gpt2": [" He", " He", "\n", " He"],
+ "llama": [" —", " The", " He", " He"],
+ "mistral": [" He", "\n", " He", " He"],
+ "qwen2": [" He", " The", " He", " He"],
+ "granite": ["\n", "\n", " I", " He"],
+ }[config_name]
+ for i, g in enumerate(generations):
+ tokens = g.tokens
+ assert tokens.texts[0] == expectations[i]
diff --git a/backends/neuron/tgi-entrypoint.sh b/backends/neuron/tgi-entrypoint.sh
new file mode 100755
index 000000000..b959a7958
--- /dev/null
+++ b/backends/neuron/tgi-entrypoint.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+set -e -o pipefail -u
+
+export ENV_FILEPATH=$(mktemp)
+
+trap "rm -f ${ENV_FILEPATH}" EXIT
+
+touch $ENV_FILEPATH
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+${SCRIPT_DIR}/tgi_env.py "$@"
+
+source $ENV_FILEPATH
+
+exec text-generation-launcher "$@"
diff --git a/backends/neuron/tgi_env.py b/backends/neuron/tgi_env.py
new file mode 100755
index 000000000..a7042130b
--- /dev/null
+++ b/backends/neuron/tgi_env.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python
+
+import argparse
+import logging
+import os
+import sys
+from typing import Any, Dict, List, Optional
+
+from huggingface_hub import constants
+from transformers import AutoConfig
+
+from optimum.neuron.modeling_decoder import get_available_cores
+from optimum.neuron.utils import get_hub_cached_entries
+from optimum.neuron.utils.version_utils import get_neuronxcc_version
+
+
+logger = logging.getLogger(__name__)
+
+tgi_router_env_vars = [
+ "MAX_BATCH_SIZE",
+ "MAX_TOTAL_TOKENS",
+ "MAX_INPUT_TOKENS",
+ "MAX_BATCH_PREFILL_TOKENS",
+]
+tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
+
+env_config_peering = [
+ ("MAX_BATCH_SIZE", "batch_size"),
+ ("MAX_TOTAL_TOKENS", "sequence_length"),
+ ("HF_AUTO_CAST_TYPE", "auto_cast_type"),
+ ("HF_NUM_CORES", "num_cores"),
+]
+
+# By the end of this script, all env vars should be specified properly
+env_vars = tgi_server_env_vars + tgi_router_env_vars
+
+available_cores = get_available_cores()
+neuronxcc_version = get_neuronxcc_version()
+
+
+def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ if not argv:
+ argv = sys.argv
+ # All these are params passed to tgi and intercepted here
+ parser.add_argument(
+ "--max-input-tokens",
+ type=int,
+ default=os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)),
+ )
+ parser.add_argument(
+ "--max-total-tokens", type=int, default=os.getenv("MAX_TOTAL_TOKENS", 0)
+ )
+ parser.add_argument(
+ "--max-batch-size", type=int, default=os.getenv("MAX_BATCH_SIZE", 0)
+ )
+ parser.add_argument(
+ "--max-batch-prefill-tokens",
+ type=int,
+ default=os.getenv("MAX_BATCH_PREFILL_TOKENS", 0),
+ )
+ parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID"))
+ parser.add_argument("--revision", type=str, default=os.getenv("REVISION"))
+
+ args = parser.parse_known_args(argv)[0]
+
+ if not args.model_id:
+ raise Exception(
+ "No model id provided ! Either specify it using --model-id cmdline or MODEL_ID env var"
+ )
+
+ # Override env with cmdline params
+ os.environ["MODEL_ID"] = args.model_id
+
+    # Set all tgi router and tgi server values to consistent values as early as possible.
+    # Given the order of the parser defaults, the tgi router values can override the tgi server ones.
+ if args.max_total_tokens > 0:
+ os.environ["MAX_TOTAL_TOKENS"] = str(args.max_total_tokens)
+
+ if args.max_input_tokens > 0:
+ os.environ["MAX_INPUT_TOKENS"] = str(args.max_input_tokens)
+
+ if args.max_batch_size > 0:
+ os.environ["MAX_BATCH_SIZE"] = str(args.max_batch_size)
+
+ if args.max_batch_prefill_tokens > 0:
+ os.environ["MAX_BATCH_PREFILL_TOKENS"] = str(args.max_batch_prefill_tokens)
+
+ if args.revision:
+ os.environ["REVISION"] = str(args.revision)
+
+ return args
+
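+# Illustrative sketch (hypothetical invocation): running
+#   tgi_env.py --model-id my-org/my-neuron-model --max-batch-size 4
+# sets MODEL_ID and MAX_BATCH_SIZE in the environment; values already present
+# in the environment are only used as argparse defaults, so an explicit
+# command line argument takes precedence over them.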
+
+def neuron_config_to_env(neuron_config):
+ with open(os.environ["ENV_FILEPATH"], "w") as f:
+ for env_var, config_key in env_config_peering:
+ f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
+ max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
+ if not max_input_tokens:
+ max_input_tokens = int(neuron_config["sequence_length"]) // 2
+ if max_input_tokens == 0:
+ raise Exception("Model sequence length should be greater than 1")
+ f.write("export MAX_INPUT_TOKENS={}\n".format(max_input_tokens))
+ max_batch_prefill_tokens = os.getenv("MAX_BATCH_PREFILL_TOKENS")
+ if not max_batch_prefill_tokens:
+ max_batch_prefill_tokens = int(neuron_config["batch_size"]) * int(
+ max_input_tokens
+ )
+ f.write("export MAX_BATCH_PREFILL_TOKENS={}\n".format(max_batch_prefill_tokens))
+
+
+def sort_neuron_configs(dictionary):
+ return -dictionary["num_cores"], -dictionary["batch_size"]
+
+
+def lookup_compatible_cached_model(
+ model_id: str, revision: Optional[str]
+) -> Optional[Dict[str, Any]]:
+    # Reuse the same mechanism as the one used to configure the tgi server part
+ # The only difference here is that we stay as flexible as possible on the compatibility part
+ entries = get_hub_cached_entries(model_id, "inference")
+
+ logger.debug(
+ "Found %d cached entries for model %s, revision %s",
+ len(entries),
+ model_id,
+ revision,
+ )
+
+ all_compatible = []
+ for entry in entries:
+ if check_env_and_neuron_config_compatibility(
+ entry, check_compiler_version=True
+ ):
+ all_compatible.append(entry)
+
+ if not all_compatible:
+ logger.debug(
+ "No compatible cached entry found for model %s, env %s, available cores %s, neuronxcc version %s",
+ model_id,
+ get_env_dict(),
+ available_cores,
+ neuronxcc_version,
+ )
+ return None
+
+ logger.info("%d compatible neuron cached models found", len(all_compatible))
+
+ all_compatible = sorted(all_compatible, key=sort_neuron_configs)
+
+ entry = all_compatible[0]
+
+ return entry
+
+
+def check_env_and_neuron_config_compatibility(
+ neuron_config: Dict[str, Any], check_compiler_version: bool
+) -> bool:
+ logger.debug(
+        "Checking that the provided neuron config %s is compatible with the local setup and the provided environment",
+ neuron_config,
+ )
+
+ # Local setup compat checks
+ if neuron_config["num_cores"] > available_cores:
+ logger.debug(
+ "Not enough neuron cores available to run the provided neuron config"
+ )
+ return False
+
+ if (
+ check_compiler_version
+ and neuron_config["compiler_version"] != neuronxcc_version
+ ):
+ logger.debug(
+ "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
+ neuronxcc_version,
+ neuron_config["compiler_version"],
+ )
+ return False
+
+ for env_var, config_key in env_config_peering:
+ neuron_config_value = str(neuron_config[config_key])
+ env_value = os.getenv(env_var, str(neuron_config_value))
+ if env_value != neuron_config_value:
+ logger.debug(
+ "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)",
+ env_var,
+ config_key,
+ env_value,
+ neuron_config_value,
+ )
+ return False
+
+ max_input_tokens = int(
+ os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
+ )
+ if max_input_tokens > 0:
+ sequence_length = neuron_config["sequence_length"]
+ if max_input_tokens >= sequence_length:
+ logger.debug(
+ "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
+ max_input_tokens,
+ sequence_length,
+ )
+ return False
+
+ return True
+
+
+def get_env_dict() -> Dict[str, str]:
+ d = {}
+ for k in env_vars:
+ d[k] = os.getenv(k)
+ return d
+
+
+def main():
+    """
+    Determine the default TGI env variables required for the Neuron precompiled models
+    to work properly.
+    """
+ args = parse_cmdline_and_set_env()
+
+ for env_var in env_vars:
+ if not os.getenv(env_var):
+ break
+ else:
+ logger.info(
+            "All env vars %s already set, skipping: the user knows what they are doing",
+ env_vars,
+ )
+ sys.exit(0)
+
+ cache_dir = constants.HF_HUB_CACHE
+
+ logger.info("Cache dir %s, model %s", cache_dir, args.model_id)
+
+ config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
+ neuron_config = getattr(config, "neuron", None)
+ if neuron_config is not None:
+ compatible = check_env_and_neuron_config_compatibility(
+ neuron_config, check_compiler_version=False
+ )
+ if not compatible:
+ env_dict = get_env_dict()
+ msg = (
+ "Invalid neuron config and env. Config {}, env {}, available cores {}, neuronxcc version {}"
+ ).format(neuron_config, env_dict, available_cores, neuronxcc_version)
+ logger.error(msg)
+ raise Exception(msg)
+ else:
+ neuron_config = lookup_compatible_cached_model(args.model_id, args.revision)
+
+ if not neuron_config:
+ msg = (
+ "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
+ ).format(get_env_dict(), available_cores, neuronxcc_version)
+ logger.error(msg)
+ raise Exception(msg)
+
+ neuron_config_to_env(neuron_config)
+
+
+if __name__ == "__main__":
+ main()
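
For reference, the short sketch below illustrates the derivation rules implemented by `neuron_config_to_env` above. The `neuron_config` dict is a hypothetical example, not taken from the diff; only the derivation rules (unset `MAX_INPUT_TOKENS` defaults to half the sequence length, unset `MAX_BATCH_PREFILL_TOKENS` defaults to `batch_size * MAX_INPUT_TOKENS`) mirror the script.

```python
# Minimal, self-contained sketch of the env derivation; the neuron_config values are made up.
neuron_config = {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"}

max_input_tokens = int(neuron_config["sequence_length"]) // 2                     # 2048 when unset
max_batch_prefill_tokens = int(neuron_config["batch_size"]) * max_input_tokens    # 8192 when unset

exports = {
    "MAX_BATCH_SIZE": neuron_config["batch_size"],
    "MAX_TOTAL_TOKENS": neuron_config["sequence_length"],
    "HF_AUTO_CAST_TYPE": neuron_config["auto_cast_type"],
    "HF_NUM_CORES": neuron_config["num_cores"],
    "MAX_INPUT_TOKENS": max_input_tokens,
    "MAX_BATCH_PREFILL_TOKENS": max_batch_prefill_tokens,
}
# These lines are what ends up in $ENV_FILEPATH and gets sourced by the entrypoint
print("\n".join(f"export {k}={v}" for k, v in exports.items()))
```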
diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt
index 831372cdf..e54fd1169 100644
--- a/backends/trtllm/CMakeLists.txt
+++ b/backends/trtllm/CMakeLists.txt
@@ -1,25 +1,19 @@
cmake_minimum_required(VERSION 3.20)
-if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
- find_program(CCACHE_EXECUTABLE "ccache")
- if (CCACHE_EXECUTABLE)
- message(STATUS "Using ccache")
- set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
- endif ()
-endif ()
-
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0)
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
include(FetchContent)
include(ExternalProject)
+include(CheckCXXCompilerFlag)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
+option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
@@ -27,13 +21,24 @@ set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE ST
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(MPI REQUIRED)
#### External dependencies ####
-include(cmake/fmt.cmake)
include(cmake/json.cmake)
include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake)
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+ set(TGI_TRTLLM_BACKEND_DEBUG ON)
+ add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
+endif ()
+
+if (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})
+ message(STATUS "Using lld linker")
+ add_link_options("-fuse-ld=lld")
+endif ()
+
# Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
@@ -41,35 +46,75 @@ add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
# TGI TRTLLM Backend definition
-add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
+add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp csrc/backend.cpp)
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
target_include_directories(tgi_trtllm_backend_impl PRIVATE
- $
- $
+ $
+ # $
)
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
-target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
-target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
+target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
+target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
# This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make it easy to link / find them back
-install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
-install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+install(TARGETS tgi_trtllm_backend_impl)
+#install(TARGETS cutlass_src fb_gemm_src fpA_intB_gemm_src gemm_swiglu_sm90_src kernels_src)
+install(TARGETS decoder_attention_0 decoder_attention_1)
+install(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention_src executorWorker)
+install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
+if (NOT ${TGI_TRTLLM_BACKEND_DEBUG})
+ install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
+endif ()
+
#### Unit Tests ####
-if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
+if (${TGI_TRTLLM_BACKEND_BUILD_TESTS} AND CMAKE_BUILD_TYPE MATCHES "Debug")
message(STATUS "Building tests")
+ option(TGI_TRTLLM_BACKEND_ENABLE_ASAN "Enable AddressSanitizer")
+ option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN "Enable UndefinedSanitizer")
+
FetchContent_Declare(
Catch2
- GIT_REPOSITORY https://github.com/catchorg/Catch2
- GIT_TAG v3.6.0
+ URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz
)
FetchContent_MakeAvailable(Catch2)
- # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
- # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
+    # This attempts to detect whether the compiler can emit a warning when it cannot apply named return value optimization (NRVO) to a function
+ check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NVRO)
+ if (${COMPILER_SUPPORT_WARNING_ON_NVRO})
+        message(STATUS "Enabling non-NRVO detection")
+ target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wnrvo)
+ endif ()
+ target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wall)
- list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
- include(CTest)
- include(Catch)
+ cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)
+ message(STATUS "Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
+
+ add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
+
+ # target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)
+ target_link_directories(tgi_trtllm_backend_tests PRIVATE "${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
+ target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
+ target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
+ target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
+ target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
+ target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
+
+ if (${TGI_TRTLLM_BACKEND_ENABLE_ASAN})
+ message(STATUS "Enabled AddressSanitizer")
+ target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)
+ endif ()
+
+ if (${TGI_TRTLLM_BACKEND_ENABLE_UBSAN})
+ message(STATUS "Enabled UndefinedSanitizer")
+ target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)
+ endif ()
+
+ install(TARGETS tgi_trtllm_backend_tests)
+
+ # list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
+ # include(CTest)
+ # include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif ()
diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
index 97ef1a768..b6c39346a 100644
--- a/backends/trtllm/Cargo.toml
+++ b/backends/trtllm/Cargo.toml
@@ -7,20 +7,17 @@ homepage.workspace = true
[dependencies]
async-trait = "0.1"
-async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
-hashbrown = "0.14"
+hashbrown = "0.15"
hf-hub = { workspace = true }
-log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" }
tokenizers = { workspace = true }
-tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.15"
+tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.17"
thiserror = "1.0.63"
tracing = "0.1"
-tracing-opentelemetry = "0.25"
-tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+pyo3 = { workspace = true }
[build-dependencies]
cmake = "0.1"
diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs
index 985019260..c9918e2c5 100644
--- a/backends/trtllm/build.rs
+++ b/backends/trtllm/build.rs
@@ -3,24 +3,34 @@ use pkg_config;
use std::env;
use std::env::consts::ARCH;
use std::path::{absolute, PathBuf};
+use std::sync::LazyLock;
-const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
+const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
-const CUDA_REQUIRED_VERSION: &str = "12.6";
+const CUDA_REQUIRED_VERSION: &str = "12.8";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
+const IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {
+ option_env!("SCCACHE_GHA_ENABLED").map_or(false, |value| match value.to_lowercase().as_str() {
+ "on" => true,
+ "true" => true,
+ "1" => true,
+ _ => false,
+ })
+});
+
// Dependencies
-const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
+const BACKEND_DEPS: &str = "tgi_trtllm_backend_impl";
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
("dylib", "tensorrt_llm"),
- ("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"),
- ("dylib", "decoder_attention"),
+ ("dylib", "decoder_attention_0"),
+ ("dylib", "decoder_attention_1"),
];
macro_rules! probe {
@@ -32,6 +42,48 @@ macro_rules! probe {
};
}
+fn get_compiler_flag(
+ switch: bool,
+ true_case: &'static str,
+ false_case: &'static str,
+) -> &'static str {
+ match switch {
+ true => true_case,
+ false => false_case,
+ }
+}
+
+fn get_library_architecture() -> &'static str {
+ let os = env::var("CARGO_CFG_TARGET_OS").unwrap();
+ let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
+ let env = env::var("CARGO_CFG_TARGET_ENV").unwrap();
+
+ match os.as_str() {
+ "linux" => {
+ if env != "gnu" {
+ panic!("unsupported linux ABI {env}, only 'gnu' is supported")
+ }
+
+ match arch.as_str() {
+ "x86_64" => "x86_64-linux-gnu",
+ "aarch64" => "aarch64-linux-gnu",
+ _ => panic!("unsupported linux architecture {arch}"),
+ }
+ }
+ "windows" => {
+ if env != "msvc" {
+ panic!("unsupported windows ABI {env}, only 'msvc' is supported")
+ }
+
+ match arch.as_str() {
+ "x86_64" => "x86_64-windows-msvc",
+ _ => panic!("unsupported windows architecture {arch}"),
+ }
+ }
+ _ => panic!("unsupported OS {os}"),
+ }
+}
+
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
@@ -43,7 +95,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
}
- let _ = cmake::Config::new(".")
+ let mut config = cmake::Config::new(".");
+ config
.uses_cxx11()
.generator("Ninja")
.profile(match is_debug {
@@ -53,9 +106,50 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
.env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
+ .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
- .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
- .build();
+ .define(
+ "TGI_TRTLLM_BACKEND_DEBUG",
+ get_compiler_flag(is_debug, "ON", "OFF"),
+ )
+ .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);
+
+ if is_debug || *IS_GHA_BUILD {
+ config.define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON");
+ }
+
+ if option_env!("USE_LLD_LINKER").is_some() {
+ println!("cargo:warning=Using lld linker");
+ config.define("TGI_TRTLLM_BACKEND_BUILD_USE_LLD", "ON");
+ }
+
+ if (is_debug && option_env!("ENABLE_ASAN").is_some()) || *IS_GHA_BUILD {
+ println!("cargo:warning=Enabling Address Sanitizer");
+ config.define("TGI_TRTLLM_BACKEND_ENABLE_ASAN", "ON");
+ }
+
+ if (is_debug && option_env!("ENABLE_UBSAN").is_some()) || *IS_GHA_BUILD {
+ println!("cargo:warning=Enabling Undefined Sanitizer");
+ config.define("TGI_TRTLLM_BACKEND_ENABLE_UBSAN", "ON");
+ }
+
+ if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
+ config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
+ }
+
+ if let Some(wrapper) = option_env!("RUSTC_WRAPPER") {
+ println!("cargo:warning=Using caching tool: {wrapper}");
+ config.define("CMAKE_C_COMPILER_LAUNCHER", wrapper);
+ config.define("CMAKE_CXX_COMPILER_LAUNCHER", wrapper);
+ config.define("CMAKE_CUDA_COMPILER_LAUNCHER", wrapper);
+ }
+
+    // Allow overriding which Python to use ...
+ if let Some(python3) = option_env!("Python3_EXECUTABLE") {
+ config.define("Python3_EXECUTABLE", python3);
+ }
+
+ config.build();
// Additional transitive CMake dependencies
let deps_folder = out_dir.join("build").join("_deps");
@@ -70,46 +164,43 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
}
// Emit linkage information from the artifacts we just built
- let install_lib_path = install_path.join("lib");
-
- println!(
- r"cargo:warning=Adding link search path: {}",
- install_lib_path.display()
- );
- println!(r"cargo:rustc-link-search={}", install_lib_path.display());
-
+ for path in ["lib", "lib64"] {
+ let install_lib_path = install_path.join(path);
+ println!(
+ r"cargo:warning=Adding link search path: {}",
+ install_lib_path.display()
+ );
+ println!(r"cargo:rustc-link-search={}", install_lib_path.display());
+ }
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
- let ndebug = match is_debug {
- true => "1",
- false => "0",
- };
-
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
- .include(deps_folder.join("fmt-src").join("include"))
+ .std("c++23")
.include(deps_folder.join("spdlog-src").join("include"))
.include(deps_folder.join("json-src").join("include"))
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
.include("/usr/local/cuda/include")
.include("/usr/local/tensorrt/include")
- .file("src/ffi.cpp")
- .std("c++20")
- .define("NDEBUG", ndebug)
+ .include("csrc/")
+ .file("csrc/ffi.hpp")
+ .define(
+ "TGI_TRTLLM_BACKEND_DEBUG",
+ get_compiler_flag(is_debug, "ON", "OFF"),
+ )
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
- println!("cargo:rerun-if-changed=cmake/fmt.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
- println!("cargo:rerun-if-changed=include/backend.h");
- println!("cargo:rerun-if-changed=lib/backend.cpp");
- println!("cargo:rerun-if-changed=include/ffi.h");
- println!("cargo:rerun-if-changed=src/ffi.cpp");
+ println!("cargo:rerun-if-changed=csrc/backend.hpp");
+ println!("cargo:rerun-if-changed=csrc/backend.cpp");
+ println!("cargo:rerun-if-changed=csrc/hardware.hpp");
+ println!("cargo:rerun-if-changed=csrc/ffi.hpp");
}
fn main() {
@@ -118,6 +209,7 @@ fn main() {
let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"),
+ "dev" => (true, "0"),
_ => (false, "3"),
};
@@ -154,7 +246,5 @@ fn main() {
});
// Backend
- BACKEND_DEPS.iter().for_each(|name| {
- println!("cargo:rustc-link-lib=static={}", name);
- });
+ println!("cargo:rustc-link-lib=static={}", &BACKEND_DEPS);
}
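
To make the conditional wiring in `build.rs` easier to follow, here is a small Python sketch (purely illustrative; the helper name is made up) of the main extra `-D` definitions forwarded to CMake depending on the build profile and environment, using the flag names introduced in the diff above.

```python
# Illustrative only: which CMake definitions build.rs ends up setting for a given environment.
import os

def trtllm_cmake_defines(is_debug: bool, env: dict) -> dict:
    # Same truthy parsing as IS_GHA_BUILD in build.rs
    is_gha_build = env.get("SCCACHE_GHA_ENABLED", "").lower() in {"on", "true", "1"}
    defines = {"TGI_TRTLLM_BACKEND_DEBUG": "ON" if is_debug else "OFF"}
    if is_debug or is_gha_build:
        defines["TGI_TRTLLM_BACKEND_BUILD_TESTS"] = "ON"
    if "USE_LLD_LINKER" in env:
        defines["TGI_TRTLLM_BACKEND_BUILD_USE_LLD"] = "ON"
    if (is_debug and "ENABLE_ASAN" in env) or is_gha_build:
        defines["TGI_TRTLLM_BACKEND_ENABLE_ASAN"] = "ON"
    if (is_debug and "ENABLE_UBSAN" in env) or is_gha_build:
        defines["TGI_TRTLLM_BACKEND_ENABLE_UBSAN"] = "ON"
    if wrapper := env.get("RUSTC_WRAPPER"):
        # sccache (or any wrapper) is reused as the C/C++/CUDA compiler launcher
        for lang in ("C", "CXX", "CUDA"):
            defines[f"CMAKE_{lang}_COMPILER_LAUNCHER"] = wrapper
    return defines

print(trtllm_cmake_defines(is_debug=False, env=dict(os.environ)))
```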
diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake
deleted file mode 100644
index afd6ea5f0..000000000
--- a/backends/trtllm/cmake/fmt.cmake
+++ /dev/null
@@ -1,6 +0,0 @@
-FetchContent_Declare(
- fmt
- DOWNLOAD_EXTRACT_TIMESTAMP
- URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
-)
-FetchContent_MakeAvailable(fmt)
diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake
index 67eff2fe6..d6cdbe3aa 100644
--- a/backends/trtllm/cmake/json.cmake
+++ b/backends/trtllm/cmake/json.cmake
@@ -1,6 +1,6 @@
fetchcontent_declare(
json
- DOWNLOAD_EXTRACT_TIMESTAMP
- URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
+# DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
)
fetchcontent_makeavailable(json)
diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake
index 7f529a7d2..e7566cd73 100644
--- a/backends/trtllm/cmake/spdlog.cmake
+++ b/backends/trtllm/cmake/spdlog.cmake
@@ -1,17 +1,17 @@
set(SPDLOG_USE_FMT ON)
set(SPDLOG_BUILD_SHARED OFF)
-set(SPDLOG_FMT_EXTERNAL ON)
+set(SPDLOG_FMT_EXTERNAL OFF)
# Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
- add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)
else ()
- add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
+ add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
endif ()
fetchcontent_declare(
spdlog
- DOWNLOAD_EXTRACT_TIMESTAMP
- URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
+ # DOWNLOAD_EXTRACT_TIMESTAMP
+ URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
)
fetchcontent_makeavailable(spdlog)
diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake
index 5f1b6c19c..95a99e9bc 100644
--- a/backends/trtllm/cmake/trtllm.cmake
+++ b/backends/trtllm/cmake/trtllm.cmake
@@ -11,20 +11,25 @@ set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+set(ENABLE_UCX OFF)
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON)
- set(NVTX_DISABLE OFF)
+ set(NVTX_DISABLE ON)
+ set(INDEX_RANGE_CHECK ON)
else ()
set(FAST_BUILD OFF)
set(FAST_MATH ON)
- set(NVTX_DISABLE ON)
+ set(NVTX_DISABLE OFF)
+ set(INDEX_RANGE_CHECK OFF)
endif ()
+find_package(Python3 REQUIRED Interpreter)
+
fetchcontent_declare(
trtllm
- GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
- GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
- GIT_SHALLOW FALSE
+ GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git
+ GIT_TAG v0.17.0
+ GIT_SHALLOW ON
DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(trtllm)
diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp
new file mode 100644
index 000000000..2151466be
--- /dev/null
+++ b/backends/trtllm/csrc/backend.cpp
@@ -0,0 +1,80 @@
+#include
+
+#include
+
+#include "backend.hpp"
+#include "hardware.hpp"
+
+namespace huggingface::tgi::backends::trtllm {
+ tle::ParallelConfig backend_workspace_t::parallel_config() const {
+ // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
+        const auto world_size = config_["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+
+ auto mode = tle::CommunicationMode::kLEADER;
+        std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;
+
+ if (world_size > 1) {
+ SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
+ mode = tle::CommunicationMode::kORCHESTRATOR;
+            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr, true);
+ } else {
+ SPDLOG_INFO("Detected single engine deployment, using leader mode");
+ }
+
+ return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
+ }
+
+
+ tle::ExecutorConfig backend_workspace_t::executor_config() const {
+ // Retrieve the compute capabilities to enable some options at runtime
+ const auto compute_capabilities = hardware::cuda::compute_capabilities_t();
+
+ // Allocate the config
+ tle::ExecutorConfig executor_config(/* maxBeamWidth = */ 1);
+
+ // Set the parallel config as inferred
+ executor_config.setParallelConfig(parallel_config());
+
+ // Define some configuration variables
+ executor_config.setKvCacheConfig(tle::KvCacheConfig(true));
+ executor_config.setEnableChunkedContext(compute_capabilities.is_at_least_ampere());
+ executor_config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
+ return executor_config;
+ }
+
+ backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
+ : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
+
+ size_t backend_t::num_tokens_ready() const noexcept {
+ return executor_.getNumResponsesReady();
+ }
+
+    std::expected<request_id_t, backend_error_t>
+    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
+                      const sampling_params_t s_params) noexcept {
+ SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
+ return executor_.enqueueRequest(tle::Request{
+ {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
+                static_cast<tle::SizeType32>(g_params.max_new_tokens),
+ true,
+ (tle::SamplingConfig) s_params,
+ tle::OutputConfig{ /* returnLogProbs= */ true},
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ workspace.generation_config().stop_words
+ });
+ }
+
+    std::vector<tle::Response> backend_t::pull_tokens() noexcept {
+ SPDLOG_TRACE(FMT_STRING("Pulling out tokens ({:d} available)"), num_tokens_ready());
+ return executor_.awaitResponses();
+ }
+
+ void backend_t::cancel(request_id_t request_id) noexcept {
+ SPDLOG_TRACE(FMT_STRING("Cancelling request: {:d}"), request_id);
+ executor_.cancelRequest(request_id);
+ }
+}
diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp
new file mode 100644
index 000000000..40b44a842
--- /dev/null
+++ b/backends/trtllm/csrc/backend.hpp
@@ -0,0 +1,231 @@
+#ifndef TGI_BACKEND_TRTLLM
+#define TGI_BACKEND_TRTLLM
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+namespace huggingface::tgi::backends::trtllm {
+ namespace tle = tensorrt_llm::executor;
+ using json = nlohmann::json;
+ using request_id_t = uint64_t;
+ using token_id_t = tle::TokenIdType;
+
+ /**
+ * Represent the parameters used for generation
+ */
+ struct generation_params_t {
+ uint32_t max_new_tokens;
+ };
+
+ /**
+ * Represent the parameters used to sample tokens from the logit distribution
+ */
+ struct sampling_params_t {
+ uint32_t top_k;
+ float_t top_p;
+ float_t repetition_penalty;
+ float_t frequency_penalty;
+ float_t temperature;
+ uint64_t seed;
+
+ constexpr explicit operator tle::SamplingConfig() const {
+ return tle::SamplingConfig{
+ 1,
+ top_k,
+ top_p,
+ std::nullopt,
+ std::nullopt,
+ std::nullopt,
+ seed,
+ temperature,
+ std::nullopt,
+ std::nullopt,
+ repetition_penalty,
+ std::nullopt,
+ frequency_penalty,
+ std::nullopt
+ };
+ }
+ };
+
+ /**
+ * Represent possible values from transformers generation `generation_config.json`.
+ * It usually stores default sampling parameters to use, such as top_p, temperature, etc.
+ */
+ struct generation_config_t {
+ float_t top_p;
+ float_t temperature;
+        std::list<std::vector<int32_t>> stop_words;
+
+ constexpr explicit generation_config_t(const json &config) :
+ top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) {
+ if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) {
+ const auto &eos_token_id = config["/eos_token_id"_json_pointer];
+ std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {
+                    stop_words.emplace_back(1, token_id.template get<int32_t>());
+ });
+
+ SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size());
+ }
+ }
+ };
+
+ /**
+ * Helper class representing various items which are stored within the TensorRT-LLM engines folder and
+ * can be retrieved at runtime
+ */
+ class backend_workspace_t {
+ private:
+ constexpr static auto as_json = [](const std::filesystem::path &path) -> json {
+ std::ifstream config_f(path);
+ return json::parse(config_f);
+ };
+
+ std::filesystem::path engines_folder_;
+ std::filesystem::path executor_worker_path_;
+ json config_;
+ generation_config_t generation_config_;
+
+ public:
+ backend_workspace_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) :
+ engines_folder_(engines_folder),
+ executor_worker_path_(executor_worker_path),
+ config_(as_json(engines_folder / "config.json")),
+ generation_config_(as_json(engines_folder / "generation_config.json")) {};
+
+ backend_workspace_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) :
+ engines_folder_(engines_folder),
+ executor_worker_path_(executor_worker_path),
+ config_(as_json(engines_folder / "config.json")),
+ generation_config_(as_json(engines_folder / "generation_config.json")) {};
+
+ /**
+ * Path to the folder containing the TensorRT-LLM engines
+ * @return local filesystem path to the folder
+ */
+ [[nodiscard]] constexpr std::filesystem::path engines_folder() const { return engines_folder_; }
+
+ /**
+ * Hugging Face transformers' generated `generation_config_t` mapping information stored in the
+ * `generation_config.json` holding default generation parameters.
+ * @return `generation_config_t`
+ */
+ [[nodiscard]] constexpr const generation_config_t &generation_config() const { return generation_config_; }
+
+ /**
+ * Factory method returning new `tensorrt_llm::executor::ParallelConfig` instance used
+ * to initialize `tensorrt_llm::executor::Executor` with multi-instance communication information
+ * @return `tensorrt_llm::executor::ParallelConfig` instance
+ */
+ [[nodiscard]] tle::ParallelConfig parallel_config() const;
+
+ /**
+ * Factory method returning new `tensorrt_llm::executor::ExecutorConfig` instance used
+ * to initialize `tensorrt_llm::executor::Executor`
+ * @return `tensorrt_llm::executor::ExecutorConfig` instance
+ */
+ [[nodiscard]] tle::ExecutorConfig executor_config() const;
+ };
+
+ /**
+ * Error raised by the underlying backend implementation
+ */
+ enum backend_error_t {
+ EXECUTOR_NOT_READY = 3,
+ EXECUTOR_SCHEDULING_FAILED = 4,
+ };
+
+
+ /**
+ * Actual TensorRT-LLM backend implementation interacting with TensorRT-LLM Executor service to
+ * - schedule new request
+ * - pull status of submitted request(s)
+ * - cancel submitted request(s)
+ */
+ class backend_t {
+ private:
+ backend_workspace_t workspace;
+ tle::Executor executor_;
+
+ public:
+ backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path);
+
+ backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path)
+ : backend_t(engines_folder, executor_worker_path) {};
+
+ /**
+ * Submit a new request to the executor
+ * @param token_ids
+ * @param generation_params
+ * @param sampling_params
+ * @return Either newly submitted request's id or the error why it failed to submit
+ */
+ [[nodiscard("Discarded executor request_id needs to be assigned")]]
+        std::expected<request_id_t, backend_error_t>
+        submit(std::span<const token_id_t> token_ids, generation_params_t generation_params,
+               sampling_params_t sampling_params) noexcept;
+
+ /**
+ * Query the number of tokens available across all in-flight generations
+ * @return
+ */
+ [[nodiscard("Pulling out the number of tokens")]]
+ size_t num_tokens_ready() const noexcept;
+
+ /**
+ * Pull out newly generated tokens from the executor
+ * @return
+ */
+ [[nodiscard("")]]
+        std::vector<tle::Response> pull_tokens() noexcept;
+
+ /**
+         * Cancel the specified request on the executor's set
+ * @param request_id Request's Identifier to remove from the in-flight executor
+ */
+ void cancel(request_id_t) noexcept;
+ };
+
+ /**
+ * Create a TensorRT-LLM executor from a workspace
+ */
+ const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor {
+ return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+ workspace.executor_config()};
+ };
+}
+
+/**
+ * Helper structures to define formatting strategies for various types in the backend
+ */
+template<>
+struct fmt::formatter<huggingface::tgi::backends::trtllm::generation_params_t> : formatter<std::string_view> {
+ auto format(huggingface::tgi::backends::trtllm::generation_params_t const &c,
+ format_context &ctx) const -> format_context::iterator {
+ return fmt::format_to(ctx.out(), "generation_params_t{{ max_new_tokens={:d} }}", c.max_new_tokens);
+ }
+};
+
+template<>
+struct fmt::formatter<huggingface::tgi::backends::trtllm::sampling_params_t> : formatter<std::string_view> {
+ auto format(huggingface::tgi::backends::trtllm::sampling_params_t const &c,
+ format_context &ctx) const -> format_context::iterator {
+ return fmt::format_to(
+ ctx.out(),
+ "sampling_params_t{{ top_k={:d}, top_p={:.3f}, repetition_penalty={:.3f}, frequency_penalty={:.3f}, temperature={:.3f}, seed={:d} }}",
+ c.top_k, c.top_p, c.repetition_penalty, c.frequency_penalty, c.temperature, c.seed
+ );
+ }
+};
+
+#endif
diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp
new file mode 100644
index 000000000..840614bbc
--- /dev/null
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -0,0 +1,191 @@
+#ifndef TGI_BACKEND_TRTLLM_FFI
+#define TGI_BACKEND_TRTLLM_FFI
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace rust::behavior {
+    template <typename Try, typename Fail>
+ static void trycatch(Try &&func, Fail &&fail) noexcept try {
+ func();
+ } catch (tensorrt_llm::common::TllmException &e) {
+ fail(e.what());
+ }
+}
+
+namespace huggingface::tgi::backends::trtllm {
+ class tensorrt_llm_backend_t;
+}
+
+#include "backends/trtllm/src/lib.rs.h"
+
+
+namespace huggingface::tgi::backends::trtllm {
+ std::once_flag backend_initialized_flag;
+
+ constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
+ switch (reason) {
+ case tle::FinishReason::kNOT_FINISHED:
+ return finish_reason_t::kNOT_FINISHED;
+ case tle::FinishReason::kSTOP_WORDS:
+ return finish_reason_t::kSTOP_WORDS;
+ case tle::FinishReason::kEND_ID:
+ return finish_reason_t::kEND_ID;
+ case tle::FinishReason::kLENGTH:
+ return finish_reason_t::kLENGTH;
+ default:
+ std::unreachable();
+ }
+ }
+
+ static auto as_generation_step = [](const tle::Response &r) {
+ const auto reqId = r.getRequestId();
+ if (!r.hasError()) [[likely]] {
+ const auto result = r.getResult();
+ const auto logits = result.logProbs.value()[0];
+ return generation_step_t{
+ reqId,
+                static_cast<uint32_t>(result.outputTokenIds[0][0]),
+ logits.back(),
+ result.isFinal,
+ as_finish_reason_t(result.finishReasons[0]),
+ false,
+ std::string()
+ };
+ } else {
+ return generation_step_t{
+ reqId,
+ 0,
+ 0.0,
+ true,
+ finish_reason_t::kNOT_FINISHED,
+ true,
+ std::move(r.getErrorMsg())
+ };
+ }
+ };
+
+
+ class tensorrt_llm_backend_t {
+ private:
+ backend_t inner_;
+
+ public:
+ tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
+ : inner_(engine_folder, executor_worker_path) {}
+
+ size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
+
+ request_id_t submit(
+                rust::Slice<const uint32_t> tokens,
+ uint32_t max_new_tokens,
+ uint32_t top_k,
+ float_t top_p,
+ float_t temperature,
+ float_t repetition_penalty,
+ float_t frequency_penalty,
+ uint64_t seed
+ ) {
+ // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
+            SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor"), tokens.size());
+
+ // Submit the request to the executor and get back a potential request_id used to track request status
+ const auto signed_tokens = std::vector(tokens.begin(), tokens.end());
+ const auto maybe_request_id = inner_.submit(
+ signed_tokens,
+ {max_new_tokens},
+ {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
+ );
+
+ // If we do have a value, let's return the request_id
+ if (maybe_request_id.has_value()) [[likely]] {
+ return *maybe_request_id;
+ } else {
+ SPDLOG_WARN("[FFI] Failed to submit request to the executor");
+ return maybe_request_id.error();
+ }
+ }
+
+        std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {
+ if (num_tokens_ready() > 0) [[likely]] {
+ const auto responses = inner_.pull_tokens();
+
+ SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
+
+ // Transform tle::Response to generation_step_t
+#ifdef __cpp_lib_ranges_to_container
+                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+#else
+ auto steps = std::vector();
+ steps.reserve(responses.size());
+ std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+#endif
+                return std::make_unique<std::vector<generation_step_t>>(steps);
+
+ } else {
+                return std::make_unique<std::vector<generation_step_t>>();
+ }
+ }
+
+ void cancel(request_id_t request_id) noexcept {
+ SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
+ inner_.cancel(request_id);
+ }
+ };
+
+ void initialize_logging() {
+#ifndef TGI_TRTLLM_BACKEND_DEBUG
+ if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
+ std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
+ std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
+ return std::tolower(c);
+ });
+
+ if (log_level == "debug")
+ spdlog::set_level(spdlog::level::debug);
+ else
+ spdlog::set_level(spdlog::level::info);
+ }
+#else
+ spdlog::set_level(spdlog::level::debug);
+#endif
+ }
+
+ void initialize_tensorrt_llm_backend() {
+        SPDLOG_INFO("Initializing TGI - TensorRT-LLM Backend (v{})", tle::version());
+
+ // Initialize everyone
+ initialize_logging();
+ nvmlInit_v2();
+ initTrtLlmPlugins();
+
+ const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();
+ if (numGpus.has_value()) {
+ SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", *numGpus);
+ } else {
+            SPDLOG_WARN("[FFI] Failed to detect Nvidia GPU(s) on the system");
+ // todo: throw
+ }
+ }
+
+    std::unique_ptr<tensorrt_llm_backend_t>
+    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+ std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
+        return std::make_unique<tensorrt_llm_backend_t>(
+ std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
+ std::filesystem::path::format::auto_format),
+ std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
+ std::filesystem::path::format::auto_format)
+ );
+ }
+}
+#endif
diff --git a/backends/trtllm/csrc/hardware.hpp b/backends/trtllm/csrc/hardware.hpp
new file mode 100644
index 000000000..abfb4afd5
--- /dev/null
+++ b/backends/trtllm/csrc/hardware.hpp
@@ -0,0 +1,81 @@
+#ifndef TGI_HARDWARE_CUDA
+#define TGI_HARDWARE_CUDA
+#include <cstdint>
+#include <optional>
+#include <tuple>
+
+#include <nvml.h>
+
+namespace huggingface::tgi::hardware::cuda {
+ static constexpr auto VOLTA = std::make_tuple(7u, 0u);
+ static constexpr auto TURING = std::make_tuple(7u, 5u);
+ static constexpr auto AMPERE = std::make_tuple(8u, 0u);
+ static constexpr auto HOPPER = std::make_tuple(9u, 0u);
+ static constexpr auto ADA_LOVELACE = std::make_tuple(8u, 9u);
+
+ /**
+ * Get the number of GPUs on the local machine
+ * @return std::nullopt if no device is available, otherwise >= 1
+ */
+    inline std::optional<size_t> get_device_count() {
+ uint32_t numGpus = 0;
+ if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
+ return numGpus;
+ }
+ return std::nullopt;
+ }
+
+ /**
+ * Store information about the version of the CUDA Compute Capabilities detected on the device
+ */
+ struct compute_capabilities_t {
+ int32_t major;
+ int32_t minor;
+
+ compute_capabilities_t(): compute_capabilities_t(0) {}
+ explicit compute_capabilities_t(size_t device_idx): major(-1), minor(-1) {
+ nvmlDevice_t device;
+ if (nvmlDeviceGetHandleByIndex_v2(device_idx, &device) == NVML_SUCCESS) {
+ nvmlDeviceGetCudaComputeCapability(device, &major, &minor);
+ }
+ };
+ compute_capabilities_t(int32_t major, int32_t minor): major(major), minor(minor) {}
+
+ /**
+ * Evaluate if the underlying capabilities is at least greater or equals to the provided 2-tuple (major, minor)
+ * @param sm Architecture version (major, minor)
+ * @return True if greater or equals to the underlying compute capabilities
+ */
+        [[nodiscard]] constexpr auto is_at_least(std::tuple<uint32_t, uint32_t> sm) const -> decltype(auto) { return std::tie(major, minor) >= sm; }
+
+ /**
+ * Check if the capabilities match at least Volta architecture (sm_70)
+ * @return true if at least Volta (>= sm_70), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_volta() const { return is_at_least(VOLTA); }
+
+ /**
+ * Check if the capabilities match at least Turing architecture (sm_75)
+ * @return true if at least Turing (>= sm_75), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_turing() const { return is_at_least(TURING); }
+
+ /**
+ * Check if the capabilities match at least Ampere architecture (sm_80)
+ * @return true if at least Ampere (>= sm_80), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_ampere() const { return is_at_least(AMPERE); }
+
+ /**
+ * Check if the capabilities match at least Ada Lovelace architecture (sm_89)
+ * @return true if at least Ada Lovelace (>= sm_89), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_ada_lovelace() const { return is_at_least(ADA_LOVELACE); }
+
+ /**
+ * Check if the capabilities match at least Hopper architecture (sm_90)
+ * @return true if at least Hopper (>= sm_90), false otherwise
+ */
+ [[nodiscard]] constexpr bool is_at_least_hopper() const { return is_at_least(HOPPER); }
+ };
+}
+#endif
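
The `is_at_least*` helpers above boil down to a lexicographic comparison of `(major, minor)` pairs. The short Python snippet below shows the same ordering for a hypothetical sm_89 device; only the comparison semantics are being illustrated, the device choice is an assumption.

```python
# Plain-Python illustration of the tuple comparison compute_capabilities_t::is_at_least relies on.
AMPERE, ADA_LOVELACE, HOPPER = (8, 0), (8, 9), (9, 0)

detected = (8, 9)                 # hypothetical sm_89 device
print(detected >= AMPERE)         # True  -> e.g. chunked context can be enabled
print(detected >= ADA_LOVELACE)   # True
print(detected >= HOPPER)         # False
```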
diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
deleted file mode 100644
index d23f62889..000000000
--- a/backends/trtllm/include/backend.h
+++ /dev/null
@@ -1,144 +0,0 @@
-//
-// Created by Morgan Funtowicz on 6/30/24.
-//
-
-#ifndef TGI_TRTLLM_BACKEND_H
-#define TGI_TRTLLM_BACKEND_H
-
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#include
-#include
-#include
-
-using json = nlohmann::json;
-namespace tle = tensorrt_llm::executor;
-
-
-#define CAST_SIZETYPE(x) static_cast(x)
-
-namespace huggingface::tgi::backends {
- using RequestId = tle::IdType;
- using TokenId = tle::TokenIdType;
-
- const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
- constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
- "Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
- constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
- "Submitting inference [{}] to the executor ({:d} already in-flight)");
- constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
- "Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
-
- /**
- * Initialize all the components required by TRTLLM.
- * It is required to call this function before attempting to load any engine
- */
- void InitializeBackend();
-
- /**
- * Initialize logging mechanism
- */
- void InitializeLogging();
-
-
- /**
- *
- * @param config TensorRT-LLM configuration object
- * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
- * @return
- */
- tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
-
- /**
- *
- * @param worldSize
- * @param workerPath
- * @return
- */
- tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
-
- /**
- * Get the sampling configuration from the parameters provided by TGI
- * @param topK
- * @param topP
- * @param temperature
- * @param repetition_penalty
- * @param frequency_penalty
- * @param seed
- * @return
- */
- tle::SamplingConfig GetSamplingConfig(
- uint32_t topK,
- float_t topP,
- float_t temperature,
- float_t repetition_penalty,
- float_t frequency_penalty,
- uint64_t seed
- ) noexcept;
-
- /**
- * Attempt to retrieve the
- * @param generationConfigPath
- * @return
- */
- std::optional>>
- GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
-
- /**
- *
- */
- class TensorRtLlmBackend {
- private:
- const json config;
- tle::Executor executor;
-
- /** Frequently accessed variables cached here **/
- uint32_t maxNumTokens;
- std::list> stopWords;
-
- public:
- explicit TensorRtLlmBackend(
- const std::filesystem::path &engineFolder,
- const std::filesystem::path &executorWorker
- );
-
- /**
- * Query the executor for the number of token available for pulling
- * @return
- */
- [[nodiscard]] size_t NumResponsesReady() const;
-
- /**
- * Submit a new generation task to the executor
- * @param tokens
- * @param topK
- * @param topP
- * @param temperature
- * @param repetitionPenalty
- * @param frequencyPenalty
- * @param seed
- * @return Request id related to this generation for reference
- */
- [[nodiscard]] RequestId Submit(
- const std::vector &tokens,
- uint32_t maxNewTokens,
- int32_t topK,
- float_t topP,
- float_t temperature,
- float_t repetitionPenalty,
- float_t frequencyPenalty,
- uint64_t seed
- );
-
- [[nodiscard]] std::vector PullNewTokens();
- };
-}
-
-
-#endif //TGI_TRTLLM_BACKEND_H
diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
deleted file mode 100644
index 449bcd4d7..000000000
--- a/backends/trtllm/include/ffi.h
+++ /dev/null
@@ -1,75 +0,0 @@
-//
-// Created by mfuntowicz on 7/11/24.
-//
-
-#ifndef TGI_TRTLLM_BACKEND_FFI_H
-#define TGI_TRTLLM_BACKEND_FFI_H
-
-#include
-#include
-#include
-#include "backend.h"
-
-namespace huggingface::tgi::backends {
- class TensorRtLlmBackendImpl;
-}
-
-// Template to support returning error from TllmException back to Rust in a Result<>
-#include
-
-namespace rust::behavior {
- template
- static void trycatch(Try &&func, Fail &&fail) noexcept try {
- func();
- } catch (tensorrt_llm::common::TllmException &e) {
- fail(e.what());
- }
-}
-
-#include "backends/trtllm/src/lib.rs.h"
-
-namespace huggingface::tgi::backends {
-
- class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
- public:
- /***
- *
- * @param engineFolder
- * @param executorWorker
- */
- TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
-
- /***
- *
- * @param tokens
- * @param maxNewTokens
- * @param topK
- * @param topP
- * @param temperature
- * @param repetition_penalty
- * @param frequency_penalty
- * @param seed
- * @return
- */
- [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
- uint64_t
- Submit(rust::Slice tokens, uint32_t maxNewTokens,
- int32_t topK, float_t topP, float_t temperature,
- float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
-
- /***
- *
- * @return
- */
- std::unique_ptr> PullTokens();
- };
-
- /***
- *
- * @param engineFolder
- * @return
- */
- std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
-}
-
-#endif //TGI_TRTLLM_BACKEND_FFI_H
diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h
deleted file mode 100644
index 9633495f4..000000000
--- a/backends/trtllm/include/hardware.h
+++ /dev/null
@@ -1,59 +0,0 @@
-//
-// Created by mfuntowicz on 7/23/24.
-//
-
-#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
-#define TGI_TRTLLM_BACKEND_HARDWARE_H
-
-#include
-#include
-#include
-#include
-#include
-
-namespace huggingface::hardware::cuda {
-
-#define AMPERE_SM_MAJOR 8
-#define HOPPER_SM_MAJOR 9
-
- /**
- * Store information about the version of the CUDA Compute Capabilities detected on the device
- */
- struct CudaComputeCapabilities {
- int32_t major;
- int32_t minor;
-
- [[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
-
- [[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
- };
-
- CudaComputeCapabilities GetCudaComputeCapabilities() {
- // Get the compute capabilities of the current hardware
- nvmlDevice_t device;
- CudaComputeCapabilities capabilities{0, 0};
- if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
- SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
- if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
- SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
- }
- }
-
- return capabilities;
- }
-
- /**
- * Return the number of GPU detected. If no GPU is detected, return size_t::max()
- * @return
- */
- std::optional GetNumDevices() {
- uint32_t numGpus = 0;
- if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
- return std::optional(numGpus);
- } else {
- return std::nullopt;
- }
- }
-}
-
-#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
deleted file mode 100644
index 4dd41de00..000000000
--- a/backends/trtllm/lib/backend.cpp
+++ /dev/null
@@ -1,203 +0,0 @@
-#include
-#include
-
-#include
-#include
-#include
-
-#include "backend.h"
-#include "hardware.h"
-
-
-void huggingface::tgi::backends::InitializeLogging() {
-#ifdef NDEBUG
- if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
- std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
- std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
- return std::tolower(c);
- });
-
- if (log_level == "debug")
- spdlog::set_level(spdlog::level::debug);
- else
- spdlog::set_level(spdlog::level::info);
- }
-#else
- spdlog::set_level(spdlog::level::debug);
-#endif
-}
-
-void huggingface::tgi::backends::InitializeBackend() {
- SPDLOG_INFO("Initializing Backend...");
- nvmlInit_v2();
- initTrtLlmPlugins();
-
- InitializeLogging();
-
- SPDLOG_INFO("Backend Executor Version: {}", tle::version());
- const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
- if (numGpus.has_value()) {
- SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
- } else {
- SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
- }
-}
-
-[[nodiscard]]
-tle::ParallelConfig
-huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
- auto mode = tle::CommunicationMode::kLEADER;
- std::optional orchestratorConfig = std::nullopt;
-
- if (worldSize > 1) {
- SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
- mode = tle::CommunicationMode::kORCHESTRATOR;
- orchestratorConfig = std::make_optional(true, workerPath, nullptr, true);
- } else {
- SPDLOG_INFO("Detected single engine deployment, using leader mode");
- }
-
- return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
-}
-
-[[nodiscard]]
-tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
- tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
-
- // Retrieve the compute capabilities to enable some options at runtime
- const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
-
- // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
- const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get();
- execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
-
- // Define some configuration variables
- execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
- execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
- execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
- return execConfig;
-}
-
-tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
- const uint32_t topK,
- const float_t topP,
- const float_t temperature,
- const float_t repetition_penalty,
- const float_t frequency_penalty,
- const uint64_t seed) noexcept {
-
- return tle::SamplingConfig(
- 1, // TGI only use a single beam
- topK,
- topP,
- std::nullopt,
- std::nullopt,
- std::nullopt,
- seed,
- temperature,
- temperature,
- std::nullopt,
- repetition_penalty,
- std::nullopt,
- frequency_penalty
- );
-}
-
-std::optional>>
-huggingface::tgi::backends::GetStopWordsFromConfig(
- const std::filesystem::path &generationConfigPath) noexcept {
- if (exists(generationConfigPath)) {
- const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
- if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
- SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
- std::list> stopWords(eosTokenIds.size());
-
- const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
- return {tokenIdObj.template get()};
- };
-
- std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
- return stopWords;
- } else {
- SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
- }
- } else {
- SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
- }
-
- return std::nullopt;
-}
-
-huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
- const std::filesystem::path &enginesFolder,
- const std::filesystem::path &executorWorker
-) :
- config(json::parse(std::ifstream(enginesFolder / "config.json"))),
- executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
- GetExecutorConfig(config, executorWorker.string())) {
-
- SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get());
-
- // Ensure we have enough GPUs on the system
- const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get();
- const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
- if (numGpus < worldSize) {
- SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
- // todo : raise exception to catch on rust side
- }
-
- // Cache variables
- maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get();
-
- // Attempt to discover stopWords from the generation_config.json
- const auto generationConfigPath = enginesFolder / "generation_config.json";
- stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list>());
-}
-
-[[nodiscard("Returned number of requests needs to be consumed")]]
-size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
-#ifdef NDEBUG
- return executor.getNumResponsesReady();
-#else
- const auto numResponses = executor.getNumResponsesReady();
- if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
- return numResponses;
-#endif
-}
-
-[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
-tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
- const std::vector &tokens,
- const uint32_t maxNewTokens,
- const int32_t topK,
- const float_t topP,
- const float_t temperature,
- const float_t repetitionPenalty,
- const float_t frequencyPenalty,
- const uint64_t seed
-) {
- const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast(maxNumTokens - tokens.size()));
-#ifndef NDEBUG
- {
- const auto &iterations = executor.getLatestIterationStats();
- const auto &lastIteration = iterations.front();
-
- SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
- SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
- SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
- }
-#endif
-
- const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
-
- // Build the request
- auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
- request.setStopWords(stopWords);
-
- // Submit to the executor for batching
- return executor.enqueueRequest(request);
-}
-
-std::vector huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
- return executor.awaitResponses();
-}
diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh
index 4c2dc26b6..e09db6b12 100755
--- a/backends/trtllm/scripts/install_tensorrt.sh
+++ b/backends/trtllm/scripts/install_tensorrt.sh
@@ -2,13 +2,13 @@
set -ex
-TRT_VER_BASE="10.4.0"
-TRT_VER_FULL="${TRT_VER_BASE}.26"
-CUDA_VER="12.6"
-CUDNN_VER="9.5.0.50-1"
-NCCL_VER="2.22.3-1+cuda12.6"
-CUBLAS_VER="12.6.3.3-1"
-NVRTC_VER="12.6.77-1"
+TRT_VER_BASE="10.8.0"
+TRT_VER_FULL="${TRT_VER_BASE}.43"
+CUDA_VER="12.8"
+CUDNN_VER="9.7.0.66-1"
+NCCL_VER="2.25.1-1+cuda${CUDA_VER}"
+CUBLAS_VER="${CUDA_VER}.3.14-1"
+NVRTC_VER="${CUDA_VER}.61-1"
for i in "$@"; do
case $i in
@@ -73,7 +73,7 @@ install_centos_requirements() {
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
- TRT_CUDA_VERSION="12.6"
+ TRT_CUDA_VERSION="12.8"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
diff --git a/backends/trtllm/scripts/setup_sccache.py b/backends/trtllm/scripts/setup_sccache.py
new file mode 100644
index 000000000..65fdee235
--- /dev/null
+++ b/backends/trtllm/scripts/setup_sccache.py
@@ -0,0 +1,51 @@
+from argparse import ArgumentParser
+
+AWS_S3_CACHING_VARIABLES = {
+ "AWS_ACCESS_KEY_ID": "aws_access_key_id",
+ "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
+ "AWS_SESSION_TOKEN": "aws_session_token",
+ "SCCACHE_REGION": "s3_region",
+ "SCCACHE_BUCKET": "s3_bucket_name",
+}
+
+ALL_CACHING_STORAGE_VARIABLES = {"AWS_S3_CACHING_VARIABLES"}
+
+
+def setup_sccache_locally():
+ from os import environ
+
+ print("Setting up Local Caching Layer")
+ for target in ALL_CACHING_STORAGE_VARIABLES:
+ for envvar in globals()[target].keys():
+ if envvar in environ:
+ print(f"Deleted {envvar} from environment variables")
+ del environ[envvar]
+
+
+def setup_sccache_for_s3():
+ from os import environ
+
+ print("Setting up AWS S3 Caching Layer")
+ for envvar in AWS_S3_CACHING_VARIABLES.keys():
+ if envvar not in environ or not environ[envvar] or len(environ[envvar]) == 0:
+ print(f"Missing definition for environment variable {envvar}")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser("TensorRT-LLM Build Caching Setup")
+
+ parser.add_argument(
+ "--is-gha-build",
+ type=str,
+ default="FALSE",
+ help="Indicate if the build is from Github Actions",
+ )
+
+ # Parse args
+ args = parser.parse_args()
+ args.is_gha_build = args.is_gha_build.lower() in {"on", "true", "1"}
+
+ if args.is_gha_build:
+ setup_sccache_for_s3()
+ else:
+ setup_sccache_locally()
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
deleted file mode 100644
index 0a92c050f..000000000
--- a/backends/trtllm/src/ffi.cpp
+++ /dev/null
@@ -1,89 +0,0 @@
-//
-// Created by mfuntowicz on 6/30/24.
-//
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include "backends/trtllm/include/ffi.h"
-
-
-huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
- const std::string_view &engineFolder,
- const std::string_view &executorWorker
-) : TensorRtLlmBackend(engineFolder, executorWorker) {}
-
-
-uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
- rust::Slice tokens,
- uint32_t maxNewTokens,
- int32_t topK,
- float_t topP,
- float_t temperature,
- float_t repetition_penalty,
- float_t frequency_penalty,
- uint64_t seed) {
-
- // This will copy all the items from the initial slice
- std::vector tokens_(tokens.begin(), tokens.end());
- return TensorRtLlmBackend::Submit(
- std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-}
-
-std::unique_ptr>
-huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
- const auto responses = TensorRtLlmBackend::PullNewTokens();
-
- auto steps = std::make_unique>();
- steps->reserve(responses.size());
-
-#ifndef NDEBUG
- SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
-#endif
-
- // Transform tle::Response to GenerationStep
- std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
- const auto reqId = r.getRequestId();
- if (!r.hasError()) {
- const auto result = r.getResult();
- return GenerationStep{
- reqId,
- static_cast(result.outputTokenIds[0][0]),
- result.logProbs.value()[0][0],
- result.isFinal,
- false,
- std::string()
- };
- } else {
- return GenerationStep{
- reqId,
- 0,
- 0.0,
- true,
- true,
- std::move(r.getErrorMsg())
- };
- }
- });
-
- return steps;
-}
-
-std::unique_ptr
-huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
- SPDLOG_INFO("Creating TensorRT-LLM Backend");
- // Unconditionally call this to initialize and discover TRTLLM plugins
- InitializeBackend();
-
- const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
- const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
- return std::make_unique(std::move(enginePath), std::move(executorPath));
-}
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index edd8caff1..085072561 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -4,24 +4,47 @@ pub mod errors;
mod looper;
mod utils;
-#[cxx::bridge(namespace = "huggingface::tgi::backends")]
+#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi {
+ #[cxx_name = "finish_reason_t"]
+ #[derive(Debug, Clone, Copy)]
+ pub enum FinishReason {
+ /// The request is not finished.
+ #[cxx_name = "kNOT_FINISHED"]
+ NotFinished = 0u8,
+
+ /// The request finished because the end id was generated.
+ #[cxx_name = "kEND_ID"]
+ EndTokenId = 1u8,
+
+ /// The request finished because a stop word was generated.
+ #[cxx_name = "kSTOP_WORDS"]
+ StopWords = 2u8,
+
+ /// The request finished because the maximum number of tokens was reached.
+ #[cxx_name = "kLENGTH"]
+ MaxLength = 3u8,
+ }
+
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
+ #[cxx_name = "generation_step_t"]
#[derive(Debug, Clone)]
pub struct GenerationStep {
request_id: u64,
token_id: u32,
log_prob: f32,
is_final: bool,
+ finish_reason: FinishReason,
has_error: bool,
error_msg: String,
}
unsafe extern "C++" {
- include!("backends/trtllm/src/ffi.cpp");
+ include!("backends/trtllm/csrc/ffi.hpp");
/// Represent an instance of the underlying TensorRT-LLM backend
+ #[cxx_name = "tensorrt_llm_backend_t"]
type TensorRtLlmBackendImpl;
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
@@ -38,21 +61,18 @@ mod ffi {
/// ```
///
/// ```
- #[rust_name = "create_tensorrt_llm_backend"]
- fn CreateTensorRtLlmBackend(
+ fn create_backend_from_engine_folder(
engine_folder: &str,
executor_worker: &str,
        ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
- #[rust_name = "num_responses_ready"]
- fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
+ fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize;
- #[rust_name = "submit"]
- fn Submit(
+ fn submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: u32,
- top_k: i32,
+ top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
@@ -60,9 +80,24 @@ mod ffi {
seed: u64,
        ) -> Result<u64>;
- #[rust_name = "pull_tokens"]
- fn PullTokens(
+ fn pull_tokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
+
+ fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
+ }
+}
+
+use ffi::FinishReason;
+use text_generation_router::FinishReason as InferFinishReason;
+
+impl From<FinishReason> for InferFinishReason {
+ fn from(reason: FinishReason) -> Self {
+ match reason {
+ FinishReason::StopWords => InferFinishReason::StopSequence,
+ FinishReason::MaxLength => InferFinishReason::Length,
+ FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
+ _ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
+ }
}
}
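
The bridge now surfaces the executor's finish reason and converts it into the router's `FinishReason`, panicking on `NotFinished` since an unfinished step should never reach the conversion. A self-contained sketch of the same `From` pattern with stand-in enums (the real types live in `ffi` and `text_generation_router`):

```rust
// Minimal sketch of the finish-reason mapping above, using stand-in enums
// instead of the real `ffi::FinishReason` / router types.
#[derive(Debug, Clone, Copy)]
enum BackendFinishReason {
    NotFinished,
    EndTokenId,
    StopWords,
    MaxLength,
}

#[derive(Debug)]
enum RouterFinishReason {
    Length,
    EndOfSequenceToken,
    StopSequence,
}

impl From<BackendFinishReason> for RouterFinishReason {
    fn from(reason: BackendFinishReason) -> Self {
        match reason {
            BackendFinishReason::StopWords => RouterFinishReason::StopSequence,
            BackendFinishReason::MaxLength => RouterFinishReason::Length,
            BackendFinishReason::EndTokenId => RouterFinishReason::EndOfSequenceToken,
            // A step that is not finished should never be converted.
            BackendFinishReason::NotFinished => {
                panic!("cannot convert an unfinished step into a finish reason")
            }
        }
    }
}

fn main() {
    let reason: RouterFinishReason = BackendFinishReason::StopWords.into();
    println!("{reason:?}"); // StopSequence
}
```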
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index e26155c16..5fed954ff 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -1,14 +1,13 @@
-use std::hint;
-use std::ops::Deref;
-use std::path::Path;
-
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
-use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::task::spawn_blocking;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
@@ -19,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+ create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
@@ -30,9 +31,10 @@ type InferResult<T> = Result<T, InferError>;
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
request: ValidGenerateRequest,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokens: Vec<u32>,
    start: Option<Instant>,
    queued: Instant,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
@@ -40,6 +42,7 @@ struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
+ finish_reason: FinishReason,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
+ finish_reason: step.finish_reason,
})
} else {
Err(GenerationError(step.error_msg.clone()))
@@ -58,31 +62,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
}
}
-/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
-struct DecodedTokenContext {
- token: DecodedToken,
- start: Option,
- queued: Instant,
- channel: UnboundedSender>,
-}
-
fn executor_status_looper(
- mut backend: UniquePtr,
max_inflight_requests: usize,
- mut waiting_requests: UnboundedReceiver,
- post_processor_sender: UnboundedSender<(u64, InferResult)>,
+ tokenizer: Tokenizer,
+    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+    mut backlog: UnboundedReceiver<GenerationContext>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
        HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
- // TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
- let awaiting_requests = waiting_requests.len();
+ let awaiting_requests = backlog.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
- if let Some(mut ctx) = waiting_requests.blocking_recv() {
+ if let Some(ctx) = backlog.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
@@ -93,7 +88,7 @@ fn executor_status_looper(
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
- generation_params.top_k as i32,
+ generation_params.top_k,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
@@ -103,7 +98,6 @@ fn executor_status_looper(
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
- ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
@@ -117,29 +111,43 @@ fn executor_status_looper(
}
}
};
+ } else {
+ break 'scheduler;
}
}
- if backend.num_responses_ready() > 0 {
- match backend.pin_mut().pull_tokens() {
+ if backend.num_tokens_ready() > 0 {
+ let mut backend = backend.pin_mut();
+ match backend.as_mut().pull_tokens() {
Ok(responses) => {
                    // Iterate through all the decoded tokens
for step in responses.deref() {
- if let Some(ctx) = in_flights.get(&step.request_id) {
- // Remove from tracked requests
- let parcel =
- DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
- token: dt,
- start: ctx.start,
- queued: ctx.queued,
- channel: ctx.streamer.clone(),
- });
+ if let Some(ctx) = in_flights.get_mut(&step.request_id) {
+ // Update the starting timestamp if not set
+                            // This value might not be the actual starting time of the request
+                            // on the executor side; we would need to expose more info from the
+                            // executor to retrieve that value.
+ // TODO : Expose actual real starting time for a request on FFI layer
+ if ctx.start.is_none() {
+ ctx.start = Some(Instant::now());
+ }
- // Submit the work to p:the post_processor
- let posted = post_processor_sender.send((step.request_id, parcel));
+ // Try to map the generation step to a DecodedToken
+ let response = match DecodedToken::try_from(step) {
+ Ok(decoded_token) => {
+ post_process_decoded_token(&tokenizer, ctx, decoded_token)
+ }
+ Err(err) => Err(err),
+ };
- if posted.is_err() || step.is_final {
- debug!("Removing {}", step.request_id);
+ // Attempt to send back the response to the client
+ if let Err(_) = ctx.streamer.send(response) {
+ // Client has dropped, remove from tracked requests
+ debug!(
+ "Client dropped - removing request {} from tracked requests",
+ step.request_id
+ );
+ backend.as_mut().cancel(step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
@@ -159,80 +167,51 @@ fn executor_status_looper(
}
}
-fn post_processor_looper(
- tokenizer: Tokenizer,
- max_inflight_requests: usize,
- mut decoded_tokens: UnboundedReceiver<(u64, InferResult)>,
-) {
- let mut states: HashMap> = HashMap::with_capacity(max_inflight_requests * 2);
+fn post_process_decoded_token(
+ tokenizer: &Tokenizer,
+ ctx: &mut GenerationContext,
+ decoded_token: DecodedToken,
+) -> InferResult<InferStreamResponse> {
+ match tokenizer.decode(&[decoded_token.id], false) {
+ Ok(text) => {
+ let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+ let token = Token {
+ id: decoded_token.id,
+ text,
+ logprob: decoded_token.log_prob,
+ special: is_special,
+ };
- 'post_processor: loop {
- if decoded_tokens.is_closed() {
- warn!("Post processor IPC is closed, loop will exit now.");
- break 'post_processor;
- }
+ // Append the token to the tracked generated tokens
+ ctx.tokens.push(token.id);
- if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
- match decoded {
- Ok(ctx) => {
- states
- .entry(request_id)
- .and_modify(|s| s.push(*&ctx.token.id))
- .or_insert_with(|| {
- let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
- state.push(*&ctx.token.id);
- state
- });
-
- let out = match tokenizer.decode(&[ctx.token.id], false) {
- Ok(text) => {
- let is_special =
- tokenizer.get_added_vocabulary().is_special_token(&text);
- let token = Token {
- id: ctx.token.id,
- text,
- logprob: ctx.token.log_prob,
- special: is_special,
- };
-
- let out = if !ctx.token.is_final {
- InferStreamResponse::Intermediate {
- token,
- top_tokens: vec![],
- }
- } else {
- let tokens = states.remove(&request_id).unwrap();
- let text = tokenizer.decode(&tokens, true);
- let generated_text = GeneratedText {
- text: text.unwrap(),
- generated_tokens: tokens.len() as u32,
- finish_reason: FinishReason::EndOfSequenceToken,
- seed: None,
- };
-
- InferStreamResponse::End {
- token,
- top_tokens: vec![],
- generated_text,
- start: ctx.start.unwrap(),
- queued: ctx.queued,
- }
- };
-
- Ok(out)
- }
- Err(err) => Err(GenerationError(err.to_string())),
- };
-
- if let Err(_) = ctx.channel.send(out) {
- warn!("Failed to send decoded token back to the user")
- }
+            // Map to the correct response depending on whether the step is final or not
+ let out = if !decoded_token.is_final {
+ InferStreamResponse::Intermediate {
+ token,
+ top_tokens: vec![],
}
- Err(_err) => {
- todo!("what do we do?")
+ } else {
+ let text = tokenizer.decode(&ctx.tokens, true);
+ let generated_text = GeneratedText {
+ text: text.unwrap(),
+ generated_tokens: ctx.tokens.len() as u32,
+ finish_reason: decoded_token.finish_reason.into(),
+ seed: None,
+ };
+
+ InferStreamResponse::End {
+ token,
+ top_tokens: vec![],
+ generated_text,
+ start: ctx.start.unwrap(),
+ queued: ctx.queued,
}
- }
+ };
+
+ Ok(out)
}
+ Err(err) => Err(GenerationError(err.to_string())),
}
}
@@ -277,11 +256,7 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
unsafe impl Send for TensorRtLlmBackendImpl {}
-pub struct TensorRtLlmBackendV2 {
- executor_looper: JoinHandle<()>,
- post_processor_looper: JoinHandle<()>,
- executor: UnboundedSender,
-}
+pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);
impl TensorRtLlmBackendV2 {
    pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
@@ -295,32 +270,17 @@ impl TensorRtLlmBackendV2 {
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
- let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
- let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
+ let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling requests state at regular interval
- let executor_looper = spawn_blocking(move || {
- executor_status_looper(
- backend,
- max_inflight_requests,
- executor_receiver,
- post_processor_sender,
- )
+ spawn_blocking(move || {
+ executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
});
- // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
- let post_processor_looper = spawn_blocking(move || {
- post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
- });
-
- Ok(TensorRtLlmBackendV2 {
- executor_looper,
- post_processor_looper,
- executor: executor_sender,
- })
+ Ok(TensorRtLlmBackendV2(executor_sender))
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
@@ -354,20 +314,21 @@ impl TensorRtLlmBackendV2 {
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
- inner: ValidGenerateRequest,
+ request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
- Self::validate(&inner)?;
+ Self::validate(&request)?;
// Open-up the stream to send tokens
        let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
- match self.executor.send(GenerationContext {
- request: inner,
+ match self.0.send(GenerationContext {
+ request,
+ streamer,
+ tokens: Vec::with_capacity(256),
start: None,
queued,
- streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
@@ -377,6 +338,10 @@ impl Backend for TensorRtLlmBackendV2 {
}
async fn health(&self, _: bool) -> bool {
- !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
+ true
+ }
+
+ fn name(&self) -> &'static str {
+ "TensorRT-LLM"
}
}
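
With the post-processor loop gone, every pulled step is decoded and streamed from the scheduler loop itself, and a failed send doubles as the signal that the client dropped the stream. A simplified, std-only sketch of that decode-and-stream step with stand-in types (the real code uses the `tokenizers` crate and tokio channels):

```rust
use std::sync::mpsc;

// Stand-in types mirroring the shapes used above; not the real TGI structs.
struct DecodedToken {
    id: u32,
    log_prob: f32,
    is_final: bool,
}

#[derive(Debug)]
enum StreamResponse {
    Intermediate { token: u32, log_prob: f32 },
    End { token: u32, generated_tokens: u32 },
}

struct GenerationContext {
    streamer: mpsc::Sender<StreamResponse>,
    tokens: Vec<u32>,
}

// Decode-and-stream step: push the token into the per-request history and
// emit either an intermediate or a final response on the same channel.
fn post_process(
    ctx: &mut GenerationContext,
    decoded: DecodedToken,
) -> Result<(), mpsc::SendError<StreamResponse>> {
    ctx.tokens.push(decoded.id);
    let out = if !decoded.is_final {
        StreamResponse::Intermediate { token: decoded.id, log_prob: decoded.log_prob }
    } else {
        StreamResponse::End { token: decoded.id, generated_tokens: ctx.tokens.len() as u32 }
    };
    // A send error means the client dropped the stream; the caller would then
    // cancel the request on the executor and forget the context.
    ctx.streamer.send(out)
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let mut ctx = GenerationContext { streamer: tx, tokens: Vec::new() };
    post_process(&mut ctx, DecodedToken { id: 42, log_prob: -0.1, is_final: false }).unwrap();
    post_process(&mut ctx, DecodedToken { id: 7, log_prob: -0.3, is_final: true }).unwrap();
    for resp in rx.try_iter() {
        println!("{resp:?}");
    }
}
```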
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 8ab8c533c..9d4bf8f21 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -3,14 +3,15 @@ use std::path::{Path, PathBuf};
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
-use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::get_base_tokenizer;
+use text_generation_router::server::{
+ get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
+};
use text_generation_router::usage_stats::UsageStatsLevel;
-use text_generation_router::{server, HubTokenizerConfig};
+use text_generation_router::{server, Tokenizer};
/// App Configuration
#[derive(Parser, Debug)]
@@ -61,16 +62,12 @@ struct Args {
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
- usage_stats: usage_stats::UsageStatsLevel,
+ usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
-async fn get_tokenizer(
- tokenizer_name: &str,
- tokenizer_config_path: Option<&str>,
- revision: Option<&str>,
-) -> Option {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -89,6 +86,10 @@ async fn get_tokenizer(
builder = builder.with_cache_dir(cache_dir.into());
}
+ if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
+ builder = builder.with_user_agent("origin", origin.as_str());
+ }
+
builder
};
@@ -126,18 +127,18 @@ async fn get_tokenizer(
// Load tokenizer and model info
let (
- tokenizer_filename,
- _config_filename,
- tokenizer_config_filename,
+ config_filename,
+ _tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
+ _model_info,
) = match api {
Type::None => (
- Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
+ None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
@@ -146,21 +147,23 @@ async fn get_tokenizer(
revision.unwrap_or_else(|| "main").to_string(),
));
- let tokenizer_filename = match api_repo.get("tokenizer.json").await {
- Ok(tokenizer_filename) => Some(tokenizer_filename),
- Err(_) => get_base_tokenizer(&api, &api_repo).await,
- };
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
+ let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
+ Some(model_info)
+ } else {
+ tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+ None
+ };
(
- tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
+ model_info,
)
}
Type::Cache(cache) => {
@@ -170,24 +173,42 @@ async fn get_tokenizer(
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
- repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
+ None,
)
}
};
- // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
- let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path
- {
- HubTokenizerConfig::from_file(filename)
- } else {
- tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+ let tokenizer: Tokenizer = {
+ use pyo3::prelude::*;
+ pyo3::Python::with_gil(|py| -> PyResult<()> {
+ py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
+ Ok(())
+ })
+ .inspect_err(|err| {
+ tracing::error!("Failed to import python tokenizer {err}");
+ })
+ .or_else(|err| {
+ let out = legacy_tokenizer_handle(config_filename.as_ref());
+ out.ok_or(err)
+ })
+ .expect("We cannot load a tokenizer");
+ let filename = "out/tokenizer.json";
+ if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+ Tokenizer::Rust(tok)
+ } else {
+ Tokenizer::Python {
+ tokenizer_name: tokenizer_name.to_string(),
+ revision: revision.map(|revision| revision.to_string()),
+ trust_remote_code: false,
+ }
+ }
};
- tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
+ Some(tokenizer)
}
#[tokio::main]
@@ -258,50 +279,52 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
}
// Create the backend
- let tokenizer = get_tokenizer(
- &tokenizer_name,
- tokenizer_config_path.as_deref(),
- revision.as_deref(),
- )
- .await
- .expect("Failed to retrieve tokenizer implementation");
+ match get_tokenizer(&tokenizer_name, revision.as_deref())
+ .await
+ .expect("Failed to retrieve tokenizer implementation")
+ {
+ Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
+ "Failed to retrieve Rust based tokenizer".to_string(),
+ )),
+ Tokenizer::Rust(tokenizer) => {
+ info!("Successfully retrieved tokenizer {}", &tokenizer_name);
+ let backend = TensorRtLlmBackendV2::new(
+ tokenizer,
+ model_id,
+ executor_worker,
+ max_concurrent_requests,
+ )?;
- info!("Successfully retrieved tokenizer {}", &tokenizer_name);
- let backend = TensorRtLlmBackendV2::new(
- tokenizer,
- model_id,
- executor_worker,
- max_concurrent_requests,
- )?;
+ info!("Successfully created backend");
- info!("Successfully created backend");
-
- // Run server
- server::run(
- backend,
- max_concurrent_requests,
- max_best_of,
- max_stop_sequences,
- max_top_n_tokens,
- max_input_tokens,
- max_total_tokens,
- validation_workers,
- auth_token,
- tokenizer_name,
- tokenizer_config_path,
- revision,
- false,
- hostname,
- port,
- cors_allow_origin,
- false,
- None,
- None,
- true,
- max_client_batch_size,
- usage_stats,
- payload_limit,
- )
- .await?;
- Ok(())
+ // Run server
+ server::run(
+ backend,
+ max_concurrent_requests,
+ max_best_of,
+ max_stop_sequences,
+ max_top_n_tokens,
+ max_input_tokens,
+ max_total_tokens,
+ validation_workers,
+ auth_token,
+ tokenizer_name,
+ tokenizer_config_path,
+ revision,
+ false,
+ hostname,
+ port,
+ cors_allow_origin,
+ false,
+ None,
+ None,
+ true,
+ max_client_batch_size,
+ usage_stats,
+ payload_limit,
+ )
+ .await?;
+ Ok(())
+ }
+ }
}
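
The launcher now resolves the tokenizer by preferring a serialized `tokenizer.json` and only falling back to a Python-backed tokenizer, which the TensorRT-LLM entrypoint then rejects. A small sketch of that selection logic with a stand-in enum (path and types are illustrative, not the router's API):

```rust
use std::path::Path;

// Stand-in for the router's tokenizer wrapper; illustrative only.
enum ResolvedTokenizer {
    Rust(String), // would hold a `tokenizers::Tokenizer` in the real code
    Python { tokenizer_name: String },
}

fn resolve_tokenizer(tokenizer_name: &str, exported_json: &Path) -> ResolvedTokenizer {
    // Prefer a serialized tokenizer.json produced by the Python resolution step;
    // fall back to a Python-backed tokenizer otherwise.
    if exported_json.exists() {
        ResolvedTokenizer::Rust(exported_json.display().to_string())
    } else {
        ResolvedTokenizer::Python { tokenizer_name: tokenizer_name.to_string() }
    }
}

fn main() {
    match resolve_tokenizer("my-model", Path::new("out/tokenizer.json")) {
        ResolvedTokenizer::Rust(path) => println!("using Rust tokenizer from {path}"),
        // The TensorRT-LLM launcher treats this case as a hard error.
        ResolvedTokenizer::Python { tokenizer_name } => {
            eprintln!("only a Python tokenizer is available for {tokenizer_name}")
        }
    }
}
```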
diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp
deleted file mode 100644
index 8520065a7..000000000
--- a/backends/trtllm/tests/infer_test.cpp
+++ /dev/null
@@ -1,14 +0,0 @@
-//
-// Created by mfuntowicz on 7/2/24.
-//
-#include
-#include
-#include "../include/backend.h"
-
-TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
- const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
- const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
-
- spdlog::info("Loading config from: {}", absolute(engines).string());
- huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
-}
diff --git a/backends/trtllm/tests/test_backend.cpp b/backends/trtllm/tests/test_backend.cpp
new file mode 100644
index 000000000..f44cc03f9
--- /dev/null
+++ b/backends/trtllm/tests/test_backend.cpp
@@ -0,0 +1,154 @@
+//
+// Created by mfuntowicz on 12/3/24.
+//
+
+#include
+#include
+#include
+
+#include "backend.hpp"
+
+using namespace huggingface::tgi::backends::trtllm;
+
+TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
+{
+ const json config_j = {{"temperature", 0.6},
+ {"top_p", 0.95},
+ {"eos_token_id", {1, 2, 3}}};
+ const auto generation_config = generation_config_t(config_j);
+
+ REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
+ REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));
+
+ // Stop words
+ REQUIRE_FALSE(generation_config.stop_words.empty());
+ REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
+ // Currently we do not support multi-tokens stop words
+ REQUIRE(lhs.size() == 1);
+ REQUIRE(rhs.size() == 1);
+ REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
+ }
+}
+
+TEST_CASE("parse generation_config.json default", "[generation_config_t]")
+{
+ const json config_j = {{"eos_token_id", {1, 2, 3}}};
+ const auto generation_config = generation_config_t(config_j);
+
+ REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+ REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+ REQUIRE_FALSE(generation_config.stop_words.empty());
+ REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
+
+    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},
+                                                                                                        {2},
+                                                                                                        {3}})) {
+ // Currently we do not support multi-tokens stop words
+ REQUIRE(lhs.size() == 1);
+ REQUIRE(rhs.size() == 1);
+ REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
+ }
+}
+
+TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
+{
+ const json config_j = {{"eos_token_id", {}}};
+ const auto generation_config = generation_config_t(config_j);
+
+ REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+ REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+ REQUIRE(generation_config.stop_words.empty());
+
+ const json config_j2 = {};
+    const auto generation_config2 = generation_config_t(config_j2);
+
+ REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
+ REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
+
+ REQUIRE(generation_config2.stop_words.empty());
+}
+
+TEST_CASE("parallel_config single", "[backend_workspace_t]")
+{
+ // Generate temporary folder
+ const auto tmp_p = std::filesystem::temp_directory_path();
+ const auto config_p = tmp_p / "config.json";
+ const auto generation_config_p = tmp_p / "generation_config.json";
+
+ // Generate content
+ std::ofstream o_config(config_p);
+ o_config << R"({"pretrained_config": {"mapping": {"world_size": 2}}})"_json;
+ o_config.close();
+
+ std::ofstream o_generation_config(generation_config_p);
+ o_generation_config << R"({"eos_token_id": []})"_json;
+ o_generation_config.close();
+
+ const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
+ const auto parallel = workspace.parallel_config();
+ REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
+ REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
+
+ std::filesystem::remove(config_p);
+ std::filesystem::remove(generation_config_p);
+}
+
+TEST_CASE("parallel_config multi", "[backend_workspace_t]")
+{
+ // Generate temporary folder
+ const auto tmp_p = std::filesystem::temp_directory_path();
+ const auto config_p = tmp_p / "config.json";
+ const auto generation_config_p = tmp_p / "generation_config.json";
+
+ // Generate content
+ std::ofstream o_config(config_p);
+ o_config << R"({"pretrained_config": {"mapping": {"world_size": 1}}})"_json;
+ o_config.close();
+
+ std::ofstream o_generation_config(generation_config_p);
+ o_generation_config << R"({"eos_token_id": []})"_json;
+ o_generation_config.close();
+
+ const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
+ const auto parallel = workspace.parallel_config();
+ REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kLEADER);
+ REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);
+
+ std::filesystem::remove(config_p);
+ std::filesystem::remove(generation_config_p);
+}
+
+TEST_CASE("executor_config", "[backend_workspace_t]")
+{
+
+}
+
+TEST_CASE("sampling_params_t to tle::SamplingConfig", "[backend_t]")
+{
+ const sampling_params_t params = {40, 0.95, 0.9, 1.0, 0.6, 2014};
+    const auto config = static_cast<tle::SamplingConfig>(params);
+
+ REQUIRE(config.getTopK().has_value());
+ REQUIRE(config.getTopK().value() == params.top_k);
+
+ REQUIRE(config.getSeed().has_value());
+ REQUIRE(config.getSeed().value() == params.seed);
+
+ REQUIRE(config.getTopP().has_value());
+ REQUIRE_THAT(*config.getTopP(), Catch::Matchers::WithinAbs(params.top_p, 1e-6f));
+
+ REQUIRE(config.getRepetitionPenalty().has_value());
+ REQUIRE_THAT(*config.getRepetitionPenalty(), Catch::Matchers::WithinAbs(params.repetition_penalty, 1e-6f));
+
+ REQUIRE(config.getFrequencyPenalty().has_value());
+ REQUIRE_THAT(*config.getFrequencyPenalty(), Catch::Matchers::WithinAbs(params.frequency_penalty, 1e-6f));
+
+ REQUIRE(config.getTemperature().has_value());
+ REQUIRE_THAT(*config.getTemperature(), Catch::Matchers::WithinAbs(params.temperature, 1e-6f));
+}
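
The new tests pin down the `generation_config.json` parsing rules: missing `temperature`/`top_p` default to 1.0 and each `eos_token_id` becomes a single-token stop word. The same rules, sketched in Rust with the `serde_json` crate (an illustration, not the backend's parser):

```rust
use serde_json::Value;

// Parsed view of generation_config.json following the rules tested above:
// missing temperature/top_p default to 1.0, and each eos_token_id becomes a
// single-token stop word. Field names match the JSON file; the struct is ours.
#[derive(Debug)]
struct GenerationConfig {
    temperature: f64,
    top_p: f64,
    stop_words: Vec<Vec<i64>>,
}

fn parse_generation_config(raw: &str) -> serde_json::Result<GenerationConfig> {
    let v: Value = serde_json::from_str(raw)?;
    let stop_words = v["eos_token_id"]
        .as_array()
        .map(|ids| ids.iter().filter_map(Value::as_i64).map(|id| vec![id]).collect())
        .unwrap_or_default();
    Ok(GenerationConfig {
        temperature: v["temperature"].as_f64().unwrap_or(1.0),
        top_p: v["top_p"].as_f64().unwrap_or(1.0),
        stop_words,
    })
}

fn main() -> serde_json::Result<()> {
    let cfg = parse_generation_config(r#"{"temperature": 0.6, "top_p": 0.95, "eos_token_id": [1, 2, 3]}"#)?;
    assert_eq!(cfg.stop_words, vec![vec![1], vec![2], vec![3]]);
    let defaults = parse_generation_config(r#"{"eos_token_id": []}"#)?;
    assert_eq!((defaults.temperature, defaults.top_p), (1.0, 1.0));
    Ok(())
}
```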
diff --git a/backends/trtllm/tests/test_hardware.cpp b/backends/trtllm/tests/test_hardware.cpp
new file mode 100644
index 000000000..e14f1f357
--- /dev/null
+++ b/backends/trtllm/tests/test_hardware.cpp
@@ -0,0 +1,82 @@
+//
+// Created by mfuntowicz on 11/16/24.
+//
+
+#include
+#include "../csrc/hardware.hpp"
+
+using namespace huggingface::tgi::hardware::cuda;
+
+TEST_CASE("is_at_least_") {
+ const static auto VOLTA_CAPABILITIES = compute_capabilities_t(7, 0);
+ REQUIRE(VOLTA_CAPABILITIES.is_at_least_volta());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_turing());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_ampere());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_hopper());
+
+ const static auto TURING_CAPABILITIES = compute_capabilities_t(7, 5);
+ REQUIRE(TURING_CAPABILITIES.is_at_least_volta());
+ REQUIRE(TURING_CAPABILITIES.is_at_least_turing());
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_ampere());
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_hopper());
+
+ const static auto AMPERE_CAPABILITIES = compute_capabilities_t(8, 0);
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least_volta());
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least_turing());
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least_ampere());
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least_hopper());
+
+ const static auto ADA_LOVELACE_CAPABILITIES = compute_capabilities_t(8, 9);
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_volta());
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_turing());
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_ampere());
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE_FALSE(ADA_LOVELACE_CAPABILITIES.is_at_least_hopper());
+
+ const static auto HOPPER_CAPABILITIES = compute_capabilities_t(9, 0);
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_volta());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_turing());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_ampere());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_ada_lovelace());
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least_hopper());
+}
+
+TEST_CASE("is_at_least") {
+ const static auto VOLTA_CAPABILITIES = compute_capabilities_t(7, 0);
+ REQUIRE(VOLTA_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(TURING));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto TURING_CAPABILITIES = compute_capabilities_t(7, 5);
+ REQUIRE(TURING_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(TURING_CAPABILITIES.is_at_least(TURING));
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto AMPERE_CAPABILITIES = compute_capabilities_t(8, 0);
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least(TURING));
+ REQUIRE(AMPERE_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto ADA_LOVELACE_CAPABILITIES = compute_capabilities_t(8, 9);
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(TURING));
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE_FALSE(ADA_LOVELACE_CAPABILITIES.is_at_least(HOPPER));
+
+ const static auto HOPPER_CAPABILITIES = compute_capabilities_t (9, 0);
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(VOLTA));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(TURING));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(AMPERE));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(ADA_LOVELACE));
+ REQUIRE(HOPPER_CAPABILITIES.is_at_least(HOPPER));
+}
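
The hardware tests boil down to an ordering on `(major, minor)` compute capabilities: `is_at_least` is a lexicographic comparison. A compact Rust equivalent of the relation being tested (names mirror the constants above, but the type is a stand-in):

```rust
// Compute capability as a (major, minor) pair; `is_at_least` is simply a
// lexicographic comparison, which is what the C++ tests above verify.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct ComputeCapability(u32, u32);

const VOLTA: ComputeCapability = ComputeCapability(7, 0);
const TURING: ComputeCapability = ComputeCapability(7, 5);
const AMPERE: ComputeCapability = ComputeCapability(8, 0);
const ADA_LOVELACE: ComputeCapability = ComputeCapability(8, 9);
const HOPPER: ComputeCapability = ComputeCapability(9, 0);

impl ComputeCapability {
    fn is_at_least(&self, other: ComputeCapability) -> bool {
        *self >= other
    }
}

fn main() {
    let ada = ADA_LOVELACE;
    assert!(ada.is_at_least(VOLTA));
    assert!(ada.is_at_least(TURING));
    assert!(ada.is_at_least(AMPERE));
    assert!(!ada.is_at_least(HOPPER));
    println!("ok");
}
```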
diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml
index 4d32474e7..0decf41ad 100644
--- a/backends/v2/Cargo.toml
+++ b/backends/v2/Cargo.toml
@@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
-jsonschema = { version = "0.17.1", features = ["draft202012"] }
+jsonschema = { version = "0.28.0" }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index cfe87f98f..adca3d5d2 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -108,6 +108,10 @@ impl Backend for BackendV2 {
fn start_health(&self) -> bool {
true
}
+
+ fn name(&self) -> &'static str {
+ "tgi-v2"
+ }
}
/// Batching logic
diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs
index 61a3eebc9..c9a9335dd 100644
--- a/backends/v2/src/queue.rs
+++ b/backends/v2/src/queue.rs
@@ -213,8 +213,7 @@ impl State {
}
// Pad prefill_token_budget to be a multiple of block size
- let prefill_token_budget =
- ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
+ let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@@ -245,9 +244,8 @@ impl State {
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else {
// pad to block size
- prefill_tokens += ((entry.request.input_length + self.block_size - 1)
- / self.block_size)
- * self.block_size;
+ prefill_tokens +=
+ entry.request.input_length.div_ceil(self.block_size) * self.block_size;
}
if self.requires_padding {
@@ -262,8 +260,7 @@ impl State {
};
// pad to block size
- decode_tokens +=
- ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
+ decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size;
}
if prefill_tokens > prefill_token_budget
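
The padding changes swap the manual `(x + block_size - 1) / block_size` idiom for the integer `div_ceil` method; both round a token count up to whole blocks before scaling back by the block size. A quick equivalence check:

```rust
fn main() {
    let block_size: u32 = 16;
    for tokens in [1u32, 16, 17, 250] {
        let old = ((tokens + block_size - 1) / block_size) * block_size;
        let new = tokens.div_ceil(block_size) * block_size;
        // Both pad the value up to the next multiple of the block size.
        assert_eq!(old, new);
        println!("{tokens:>3} tokens -> {new} padded");
    }
}
```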
diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml
index 69dad072f..588a2716f 100644
--- a/backends/v3/Cargo.toml
+++ b/backends/v3/Cargo.toml
@@ -23,7 +23,7 @@ clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
-jsonschema = { version = "0.17.1", features = ["draft202012"] }
+jsonschema = { version = "0.28.0" }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
@@ -71,6 +71,7 @@ prost-build = "0.12.1"
[dev-dependencies]
criterion = "0.3"
itertools = "0.13"
+rustc-hash = "2"
[features]
default = ["ngrok"]
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 736301b33..98e8d76f0 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -115,6 +115,10 @@ impl Backend for BackendV3 {
fn start_health(&self) -> bool {
true
}
+
+ fn name(&self) -> &'static str {
+ "tgi-v3"
+ }
}
/// Batching logic
diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs
index 4fea172b6..6da2b51da 100644
--- a/backends/v3/src/block_allocator.rs
+++ b/backends/v3/src/block_allocator.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
-
+use text_generation_router::usage_stats::Env;
#[derive(Debug, Clone)]
pub struct BlockAllocation {
pub allocation_id: u64,
@@ -141,6 +141,7 @@ pub struct SimpleAllocator {
    free_blocks: Vec<u32>,
    block_size: u32,
    window_size: Option<u32>,
+ is_hpu_device: bool,
}
impl SimpleAllocator {
@@ -150,6 +151,7 @@ impl SimpleAllocator {
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
+ is_hpu_device: Env::new().is_hpu_device(),
}
}
}
@@ -165,13 +167,13 @@ impl Allocator for SimpleAllocator {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
- let repeats = (tokens + window_size - 1) / window_size;
+ let repeats = tokens.div_ceil(window_size);
let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
- let required_blocks = (tokens + self.block_size - 1) / self.block_size;
+ let required_blocks = tokens.div_ceil(self.block_size);
(required_blocks, repeats)
};
@@ -179,9 +181,15 @@ impl Allocator for SimpleAllocator {
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
- let blocks = self
+ if self.is_hpu_device {
+ self.free_blocks.sort_by(|a, b| b.cmp(a));
+ }
+ let mut blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
+ if self.is_hpu_device {
+ blocks.sort();
+ }
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
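
On HPU devices the simple allocator now sorts the free list descending before `split_off`, so the blocks taken from the tail are the lowest-numbered ones, then re-sorts the taken set ascending. A stand-alone sketch of that allocation path (plain `Vec<u32>`, no allocator state):

```rust
// Mimics the HPU-specific path above: hand out the lowest-numbered free blocks
// in ascending order, while `split_off` still takes from the tail of the free list.
fn allocate_lowest(free_blocks: &mut Vec<u32>, required: usize) -> Option<Vec<u32>> {
    if required > free_blocks.len() {
        return None;
    }
    // Descending sort puts the smallest block ids at the tail...
    free_blocks.sort_by(|a, b| b.cmp(a));
    // ...so split_off(len - required) takes exactly those smallest ids.
    let mut blocks = free_blocks.split_off(free_blocks.len() - required);
    blocks.sort();
    Some(blocks)
}

fn main() {
    let mut free: Vec<u32> = vec![7, 3, 12, 5, 9];
    let blocks = allocate_lowest(&mut free, 3).unwrap();
    assert_eq!(blocks, vec![3, 5, 7]);
    assert_eq!(free, vec![12, 9]); // remaining free blocks, still descending
}
```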
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index dd27806f9..d3bf4b9c0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -257,8 +257,7 @@ impl State {
}
// Pad prefill_token_budget to be a multiple of block size
- let prefill_token_budget =
- ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
+ let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
@@ -312,7 +311,7 @@ impl State {
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
- tracing::debug!("Allocating {tokens} with {input_ids:?}");
+ // tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
@@ -323,7 +322,7 @@ impl State {
break 'entry_loop;
}
Some(mut block_allocation) => {
- tracing::debug!("Allocation: {block_allocation:?}");
+ // tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len == entry.request.input_length {
diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index 8a5448911..aea69693f 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -103,7 +103,7 @@ impl Allocator for RadixAllocator {
let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;
- let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
+ let suffix_blocks = suffix_len.div_ceil(self.block_size);
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
@@ -283,7 +283,7 @@ impl RadixTrie {
}
/// Find worker.
-    fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+    fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if key.len() >= self.block_size {
@@ -295,9 +295,13 @@ impl RadixTrie {
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
+ // A node represents the prefix of its children. So, only
+ // recurse when there is a full prefix match.
let key = &key[shared_prefix_len..];
- if !key.is_empty() {
- node_id = self.find_(child_id, key, blocks);
+ if !key.is_empty() && shared_prefix_len == child.key.len() {
+ return self.find_(child_id, key, blocks);
+ } else {
+ return child_id;
}
}
}
@@ -631,6 +635,12 @@ fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
mod tests {
use std::sync::Arc;
+ use rand::{
+ distributions::Uniform, prelude::Distribution, rngs::SmallRng, seq::SliceRandom,
+ SeedableRng,
+ };
+ use rustc_hash::FxHashSet;
+
use super::*;
#[test]
@@ -873,4 +883,159 @@ mod tests {
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
+
+ #[test]
+ fn full_match_returns_correct_node() {
+ let mut trie = RadixTrie::new(1);
+ trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+ let node_id = trie.find(&[0, 1, 2], &mut vec![]);
+ // At this point, there are only two nodes: the root and the node
+ // with tokens 0, 1, 2. Looking up the exact prefix must return
+ // the non-root node.
+ assert_ne!(node_id, trie.root);
+ }
+
+ #[test]
+ fn partial_match_does_not_recurse() {
+ let mut trie = RadixTrie::new(1);
+ trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+ trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5])
+ .unwrap();
+ let mut blocks = Vec::new();
+ let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks);
+ assert_eq!(blocks, vec![0, 1]);
+ assert_eq!(node_id, trie.find(&[0, 1], &mut blocks))
+ }
+
+ struct AllocationWithInfo {
+ allocation: BlockAllocation,
+ // We are doing a lot of set operations and `FxBuildHasher` is
+        // much faster for a set of integers.
+        blockset: FxHashSet<u32>,
+        non_prefix_blocks: FxHashSet<u32>,
+ }
+
+ #[test]
+ fn invariants_hold_on_many_operations_remove_all() {
+ invariants_hold_on_many_insertions(true);
+ }
+
+ #[test]
+ fn invariants_hold_on_many_operations_remove_subset() {
+ invariants_hold_on_many_insertions(false);
+ }
+
+ fn invariants_hold_on_many_insertions(remove_all: bool) {
+ // Small vocabulary sizes lead to violations more quickly due to
+ // prefix sharing, etc.
+ const VOCAB_SIZE: u32 = 2;
+ const DATA_LEN: usize = 1_000;
+
+ const MAX_PREFILL_LEN: usize = 8;
+ const MAX_DECODE_LEN: usize = 8;
+
+ let vocab_range = Uniform::new(0, VOCAB_SIZE);
+ let data_range = Uniform::new(0, DATA_LEN);
+ let prefill_len_range = Uniform::new(0, MAX_PREFILL_LEN);
+ let decode_len_range = Uniform::new(0, MAX_DECODE_LEN);
+
+ let mut rng = SmallRng::seed_from_u64(64);
+ let data = (0..DATA_LEN)
+ .map(|_| vocab_range.sample(&mut rng))
+            .collect::<Vec<u32>>();
+ let mut allocator = RadixAllocator::new(1, 100, None);
+
+ let mut allocations = Vec::new();
+
+ for i in 0..100_000 {
+ // Allocate until all blocks are used.
+ 'allocation: loop {
+ // Use offset 0 half of the times for prefix sharing.
+ let prefill_offset = data_range.sample(&mut rng);
+ let prefill_len = prefill_len_range.sample(&mut rng);
+ let decode_len = decode_len_range.sample(&mut rng);
+
+ let prefill =
+ data[prefill_offset..data.len().min(prefill_offset + prefill_len)].to_vec();
+
+ let allocation = match allocator
+ .allocate((prefill.len() + decode_len) as u32, Some(Arc::new(prefill)))
+ {
+ Some(allocation) => allocation,
+ None => break 'allocation,
+ };
+ let non_prefix_blocks = allocation.blocks[allocation.prefix_len as usize..]
+ .iter()
+ .copied()
+                    .collect::<FxHashSet<_>>();
+                let blockset = allocation.blocks.iter().copied().collect::<FxHashSet<_>>();
+
+ // No duplicate blocks in an allocation.
+ assert_eq!(
+ allocation.blocks.len(),
+ blockset.len(),
+ "Duplicate blocks in allocation"
+ );
+
+ allocations.push(AllocationWithInfo {
+ allocation,
+ blockset,
+ non_prefix_blocks,
+ });
+ }
+
+ // Check invariants. Skip first iteration, since there is no prefix sharing yet.
+ if i > 1 {
+ check_allocation_invariants(&allocations);
+ }
+
+            // Free either all allocations, or a random 20% of them.
+ if remove_all {
+ allocations.into_iter().for_each(|allocation| {
+ allocator.free(
+ allocation.allocation.blocks.clone(),
+ allocation.allocation.allocation_id,
+ )
+ });
+ allocations = Vec::new();
+ } else {
+ allocations.shuffle(&mut rng);
+ let remove_index = (allocations.len() as f64 * 0.8) as usize;
+ for allocation in allocations.drain(remove_index..) {
+ allocator.free(
+ allocation.allocation.blocks.clone(),
+ allocation.allocation.allocation_id,
+ );
+ }
+ }
+ }
+ }
+
+ fn check_allocation_invariants(allocations: &[AllocationWithInfo]) {
+ for i in 0..allocations.len() {
+ let allocation = &allocations[i];
+
+ // 0 is used for health checks, must not be used.
+ assert!(
+ !allocation.blockset.contains(&0),
+ "Block 0 must not be allocated"
+ );
+
+ // No duplicate blocks in an allocation.
+ assert_eq!(
+ allocation.allocation.blocks.len(),
+ allocation.blockset.len(),
+ "Duplicate blocks in allocation"
+ );
+
+ for other_allocation in &allocations[i + 1..] {
+ assert!(
+ other_allocation
+ .non_prefix_blocks
+ .is_disjoint(&allocation.non_prefix_blocks),
+ "Allocations share non-prefix blocks"
+ )
+ }
+ }
+ }
}
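
The trie fix only recurses into a child when the shared prefix covers the child's entire key, and the prefix itself is always counted in whole blocks. A sketch of that block-aligned shared-prefix computation (assumed behaviour of the `shared_prefix` helper referenced above, not its verbatim implementation):

```rust
// Block-aligned shared prefix length, the quantity the radix lookup above is
// based on: count matching leading tokens, then round down to a whole number
// of blocks so partially-matched blocks are never reused.
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
    let matching = left.iter().zip(right).take_while(|(a, b)| a == b).count();
    (matching / block_size) * block_size
}

fn main() {
    // With block_size = 2, five matching tokens only cover two full blocks.
    let a = [0u32, 1, 2, 3, 4, 9, 9];
    let b = [0u32, 1, 2, 3, 4, 5, 6];
    assert_eq!(shared_prefix(&a, &b, 2), 4);
    // A partially matched block is dropped entirely.
    assert_eq!(shared_prefix(&[0, 1, 2], &b, 2), 2);
}
```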
diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock
index 148d99065..36e82f2a0 100644
--- a/clients/python/poetry.lock
+++ b/clients/python/poetry.lock
@@ -1,124 +1,131 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand.
+
+[[package]]
+name = "aiohappyeyeballs"
+version = "2.6.1"
+description = "Happy Eyeballs for asyncio"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8"},
+ {file = "aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558"},
+]
[[package]]
name = "aiohttp"
-version = "3.8.5"
+version = "3.11.16"
description = "Async http client/server framework (asyncio)"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a94159871304770da4dd371f4291b20cac04e8c94f11bdea1c3478e557fbe0d8"},
- {file = "aiohttp-3.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:13bf85afc99ce6f9ee3567b04501f18f9f8dbbb2ea11ed1a2e079670403a7c84"},
- {file = "aiohttp-3.8.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ce2ac5708501afc4847221a521f7e4b245abf5178cf5ddae9d5b3856ddb2f3a"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96943e5dcc37a6529d18766597c491798b7eb7a61d48878611298afc1fca946c"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ad5c3c4590bb3cc28b4382f031f3783f25ec223557124c68754a2231d989e2b"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c413c633d0512df4dc7fd2373ec06cc6a815b7b6d6c2f208ada7e9e93a5061d"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df72ac063b97837a80d80dec8d54c241af059cc9bb42c4de68bd5b61ceb37caa"},
- {file = "aiohttp-3.8.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c48c5c0271149cfe467c0ff8eb941279fd6e3f65c9a388c984e0e6cf57538e14"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:368a42363c4d70ab52c2c6420a57f190ed3dfaca6a1b19afda8165ee16416a82"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7607ec3ce4993464368505888af5beb446845a014bc676d349efec0e05085905"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0d21c684808288a98914e5aaf2a7c6a3179d4df11d249799c32d1808e79503b5"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:312fcfbacc7880a8da0ae8b6abc6cc7d752e9caa0051a53d217a650b25e9a691"},
- {file = "aiohttp-3.8.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad093e823df03bb3fd37e7dec9d4670c34f9e24aeace76808fc20a507cace825"},
- {file = "aiohttp-3.8.5-cp310-cp310-win32.whl", hash = "sha256:33279701c04351a2914e1100b62b2a7fdb9a25995c4a104259f9a5ead7ed4802"},
- {file = "aiohttp-3.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:6e4a280e4b975a2e7745573e3fc9c9ba0d1194a3738ce1cbaa80626cc9b4f4df"},
- {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae871a964e1987a943d83d6709d20ec6103ca1eaf52f7e0d36ee1b5bebb8b9b9"},
- {file = "aiohttp-3.8.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:461908b2578955045efde733719d62f2b649c404189a09a632d245b445c9c975"},
- {file = "aiohttp-3.8.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:72a860c215e26192379f57cae5ab12b168b75db8271f111019509a1196dfc780"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc14be025665dba6202b6a71cfcdb53210cc498e50068bc088076624471f8bb9"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8af740fc2711ad85f1a5c034a435782fbd5b5f8314c9a3ef071424a8158d7f6b"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:841cd8233cbd2111a0ef0a522ce016357c5e3aff8a8ce92bcfa14cef890d698f"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed1c46fb119f1b59304b5ec89f834f07124cd23ae5b74288e364477641060ff"},
- {file = "aiohttp-3.8.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84f8ae3e09a34f35c18fa57f015cc394bd1389bce02503fb30c394d04ee6b938"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62360cb771707cb70a6fd114b9871d20d7dd2163a0feafe43fd115cfe4fe845e"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23fb25a9f0a1ca1f24c0a371523546366bb642397c94ab45ad3aedf2941cec6a"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0ba0d15164eae3d878260d4c4df859bbdc6466e9e6689c344a13334f988bb53"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5d20003b635fc6ae3f96d7260281dfaf1894fc3aa24d1888a9b2628e97c241e5"},
- {file = "aiohttp-3.8.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0175d745d9e85c40dcc51c8f88c74bfbaef9e7afeeeb9d03c37977270303064c"},
- {file = "aiohttp-3.8.5-cp311-cp311-win32.whl", hash = "sha256:2e1b1e51b0774408f091d268648e3d57f7260c1682e7d3a63cb00d22d71bb945"},
- {file = "aiohttp-3.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:043d2299f6dfdc92f0ac5e995dfc56668e1587cea7f9aa9d8a78a1b6554e5755"},
- {file = "aiohttp-3.8.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cae533195e8122584ec87531d6df000ad07737eaa3c81209e85c928854d2195c"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f21e83f355643c345177a5d1d8079f9f28b5133bcd154193b799d380331d5d3"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a75ef35f2df54ad55dbf4b73fe1da96f370e51b10c91f08b19603c64004acc"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e2e9839e14dd5308ee773c97115f1e0a1cb1d75cbeeee9f33824fa5144c7634"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44e65da1de4403d0576473e2344828ef9c4c6244d65cf4b75549bb46d40b8dd"},
- {file = "aiohttp-3.8.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d847e4cde6ecc19125ccbc9bfac4a7ab37c234dd88fbb3c5c524e8e14da543"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:c7a815258e5895d8900aec4454f38dca9aed71085f227537208057853f9d13f2"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:8b929b9bd7cd7c3939f8bcfffa92fae7480bd1aa425279d51a89327d600c704d"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:5db3a5b833764280ed7618393832e0853e40f3d3e9aa128ac0ba0f8278d08649"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:a0215ce6041d501f3155dc219712bc41252d0ab76474615b9700d63d4d9292af"},
- {file = "aiohttp-3.8.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:fd1ed388ea7fbed22c4968dd64bab0198de60750a25fe8c0c9d4bef5abe13824"},
- {file = "aiohttp-3.8.5-cp36-cp36m-win32.whl", hash = "sha256:6e6783bcc45f397fdebc118d772103d751b54cddf5b60fbcc958382d7dd64f3e"},
- {file = "aiohttp-3.8.5-cp36-cp36m-win_amd64.whl", hash = "sha256:b5411d82cddd212644cf9360879eb5080f0d5f7d809d03262c50dad02f01421a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:01d4c0c874aa4ddfb8098e85d10b5e875a70adc63db91f1ae65a4b04d3344cda"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5980a746d547a6ba173fd5ee85ce9077e72d118758db05d229044b469d9029a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a482e6da906d5e6e653be079b29bc173a48e381600161c9932d89dfae5942ef"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80bd372b8d0715c66c974cf57fe363621a02f359f1ec81cba97366948c7fc873"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1161b345c0a444ebcf46bf0a740ba5dcf50612fd3d0528883fdc0eff578006a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd56db019015b6acfaaf92e1ac40eb8434847d9bf88b4be4efe5bfd260aee692"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:153c2549f6c004d2754cc60603d4668899c9895b8a89397444a9c4efa282aaf4"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4a01951fabc4ce26ab791da5f3f24dca6d9a6f24121746eb19756416ff2d881b"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bfb9162dcf01f615462b995a516ba03e769de0789de1cadc0f916265c257e5d8"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7dde0009408969a43b04c16cbbe252c4f5ef4574ac226bc8815cd7342d2028b6"},
- {file = "aiohttp-3.8.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4149d34c32f9638f38f544b3977a4c24052042affa895352d3636fa8bffd030a"},
- {file = "aiohttp-3.8.5-cp37-cp37m-win32.whl", hash = "sha256:68c5a82c8779bdfc6367c967a4a1b2aa52cd3595388bf5961a62158ee8a59e22"},
- {file = "aiohttp-3.8.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2cf57fb50be5f52bda004b8893e63b48530ed9f0d6c96c84620dc92fe3cd9b9d"},
- {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:eca4bf3734c541dc4f374ad6010a68ff6c6748f00451707f39857f429ca36ced"},
- {file = "aiohttp-3.8.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1274477e4c71ce8cfe6c1ec2f806d57c015ebf84d83373676036e256bc55d690"},
- {file = "aiohttp-3.8.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28c543e54710d6158fc6f439296c7865b29e0b616629767e685a7185fab4a6b9"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910bec0c49637d213f5d9877105d26e0c4a4de2f8b1b29405ff37e9fc0ad52b8"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5443910d662db951b2e58eb70b0fbe6b6e2ae613477129a5805d0b66c54b6cb7"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e460be6978fc24e3df83193dc0cc4de46c9909ed92dd47d349a452ef49325b7"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1558def481d84f03b45888473fc5a1f35747b5f334ef4e7a571bc0dfcb11f8"},
- {file = "aiohttp-3.8.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34dd0c107799dcbbf7d48b53be761a013c0adf5571bf50c4ecad5643fe9cfcd0"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aa1990247f02a54185dc0dff92a6904521172a22664c863a03ff64c42f9b5410"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0e584a10f204a617d71d359fe383406305a4b595b333721fa50b867b4a0a1548"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a3cf433f127efa43fee6b90ea4c6edf6c4a17109d1d037d1a52abec84d8f2e42"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c11f5b099adafb18e65c2c997d57108b5bbeaa9eeee64a84302c0978b1ec948b"},
- {file = "aiohttp-3.8.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:84de26ddf621d7ac4c975dbea4c945860e08cccde492269db4e1538a6a6f3c35"},
- {file = "aiohttp-3.8.5-cp38-cp38-win32.whl", hash = "sha256:ab88bafedc57dd0aab55fa728ea10c1911f7e4d8b43e1d838a1739f33712921c"},
- {file = "aiohttp-3.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:5798a9aad1879f626589f3df0f8b79b3608a92e9beab10e5fda02c8a2c60db2e"},
- {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a6ce61195c6a19c785df04e71a4537e29eaa2c50fe745b732aa937c0c77169f3"},
- {file = "aiohttp-3.8.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:773dd01706d4db536335fcfae6ea2440a70ceb03dd3e7378f3e815b03c97ab51"},
- {file = "aiohttp-3.8.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f83a552443a526ea38d064588613aca983d0ee0038801bc93c0c916428310c28"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f7372f7341fcc16f57b2caded43e81ddd18df53320b6f9f042acad41f8e049a"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea353162f249c8097ea63c2169dd1aa55de1e8fecbe63412a9bc50816e87b761"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d47ae48db0b2dcf70bc8a3bc72b3de86e2a590fc299fdbbb15af320d2659de"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d827176898a2b0b09694fbd1088c7a31836d1a505c243811c87ae53a3f6273c1"},
- {file = "aiohttp-3.8.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3562b06567c06439d8b447037bb655ef69786c590b1de86c7ab81efe1c9c15d8"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e874cbf8caf8959d2adf572a78bba17cb0e9d7e51bb83d86a3697b686a0ab4d"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6809a00deaf3810e38c628e9a33271892f815b853605a936e2e9e5129762356c"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:33776e945d89b29251b33a7e7d006ce86447b2cfd66db5e5ded4e5cd0340585c"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eaeed7abfb5d64c539e2db173f63631455f1196c37d9d8d873fc316470dfbacd"},
- {file = "aiohttp-3.8.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e91d635961bec2d8f19dfeb41a539eb94bd073f075ca6dae6c8dc0ee89ad6f91"},
- {file = "aiohttp-3.8.5-cp39-cp39-win32.whl", hash = "sha256:00ad4b6f185ec67f3e6562e8a1d2b69660be43070bd0ef6fcec5211154c7df67"},
- {file = "aiohttp-3.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:c0a9034379a37ae42dea7ac1e048352d96286626251862e448933c0f59cbd79c"},
- {file = "aiohttp-3.8.5.tar.gz", hash = "sha256:b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"},
+ {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fb46bb0f24813e6cede6cc07b1961d4b04f331f7112a23b5e21f567da4ee50aa"},
+ {file = "aiohttp-3.11.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:54eb3aead72a5c19fad07219acd882c1643a1027fbcdefac9b502c267242f955"},
+ {file = "aiohttp-3.11.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:38bea84ee4fe24ebcc8edeb7b54bf20f06fd53ce4d2cc8b74344c5b9620597fd"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0666afbe984f6933fe72cd1f1c3560d8c55880a0bdd728ad774006eb4241ecd"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ba92a2d9ace559a0a14b03d87f47e021e4fa7681dc6970ebbc7b447c7d4b7cd"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ad1d59fd7114e6a08c4814983bb498f391c699f3c78712770077518cae63ff7"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b88a2bf26965f2015a771381624dd4b0839034b70d406dc74fd8be4cc053e3"},
+ {file = "aiohttp-3.11.16-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:576f5ca28d1b3276026f7df3ec841ae460e0fc3aac2a47cbf72eabcfc0f102e1"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a2a450bcce4931b295fc0848f384834c3f9b00edfc2150baafb4488c27953de6"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:37dcee4906454ae377be5937ab2a66a9a88377b11dd7c072df7a7c142b63c37c"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4d0c970c0d602b1017e2067ff3b7dac41c98fef4f7472ec2ea26fd8a4e8c2149"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:004511d3413737700835e949433536a2fe95a7d0297edd911a1e9705c5b5ea43"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:c15b2271c44da77ee9d822552201180779e5e942f3a71fb74e026bf6172ff287"},
+ {file = "aiohttp-3.11.16-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ad9509ffb2396483ceacb1eee9134724443ee45b92141105a4645857244aecc8"},
+ {file = "aiohttp-3.11.16-cp310-cp310-win32.whl", hash = "sha256:634d96869be6c4dc232fc503e03e40c42d32cfaa51712aee181e922e61d74814"},
+ {file = "aiohttp-3.11.16-cp310-cp310-win_amd64.whl", hash = "sha256:938f756c2b9374bbcc262a37eea521d8a0e6458162f2a9c26329cc87fdf06534"},
+ {file = "aiohttp-3.11.16-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8cb0688a8d81c63d716e867d59a9ccc389e97ac7037ebef904c2b89334407180"},
+ {file = "aiohttp-3.11.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ad1fb47da60ae1ddfb316f0ff16d1f3b8e844d1a1e154641928ea0583d486ed"},
+ {file = "aiohttp-3.11.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:df7db76400bf46ec6a0a73192b14c8295bdb9812053f4fe53f4e789f3ea66bbb"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc3a145479a76ad0ed646434d09216d33d08eef0d8c9a11f5ae5cdc37caa3540"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d007aa39a52d62373bd23428ba4a2546eed0e7643d7bf2e41ddcefd54519842c"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6ddd90d9fb4b501c97a4458f1c1720e42432c26cb76d28177c5b5ad4e332601"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a2f451849e6b39e5c226803dcacfa9c7133e9825dcefd2f4e837a2ec5a3bb98"},
+ {file = "aiohttp-3.11.16-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8df6612df74409080575dca38a5237282865408016e65636a76a2eb9348c2567"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:78e6e23b954644737e385befa0deb20233e2dfddf95dd11e9db752bdd2a294d3"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:696ef00e8a1f0cec5e30640e64eca75d8e777933d1438f4facc9c0cdf288a810"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e3538bc9fe1b902bef51372462e3d7c96fce2b566642512138a480b7adc9d508"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3ab3367bb7f61ad18793fea2ef71f2d181c528c87948638366bf1de26e239183"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:56a3443aca82abda0e07be2e1ecb76a050714faf2be84256dae291182ba59049"},
+ {file = "aiohttp-3.11.16-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:61c721764e41af907c9d16b6daa05a458f066015abd35923051be8705108ed17"},
+ {file = "aiohttp-3.11.16-cp311-cp311-win32.whl", hash = "sha256:3e061b09f6fa42997cf627307f220315e313ece74907d35776ec4373ed718b86"},
+ {file = "aiohttp-3.11.16-cp311-cp311-win_amd64.whl", hash = "sha256:745f1ed5e2c687baefc3c5e7b4304e91bf3e2f32834d07baaee243e349624b24"},
+ {file = "aiohttp-3.11.16-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:911a6e91d08bb2c72938bc17f0a2d97864c531536b7832abee6429d5296e5b27"},
+ {file = "aiohttp-3.11.16-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac13b71761e49d5f9e4d05d33683bbafef753e876e8e5a7ef26e937dd766713"},
+ {file = "aiohttp-3.11.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fd36c119c5d6551bce374fcb5c19269638f8d09862445f85a5a48596fd59f4bb"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d489d9778522fbd0f8d6a5c6e48e3514f11be81cb0a5954bdda06f7e1594b321"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:69a2cbd61788d26f8f1e626e188044834f37f6ae3f937bd9f08b65fc9d7e514e"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd464ba806e27ee24a91362ba3621bfc39dbbb8b79f2e1340201615197370f7c"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce63ae04719513dd2651202352a2beb9f67f55cb8490c40f056cea3c5c355ce"},
+ {file = "aiohttp-3.11.16-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09b00dd520d88eac9d1768439a59ab3d145065c91a8fab97f900d1b5f802895e"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7f6428fee52d2bcf96a8aa7b62095b190ee341ab0e6b1bcf50c615d7966fd45b"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:13ceac2c5cdcc3f64b9015710221ddf81c900c5febc505dbd8f810e770011540"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fadbb8f1d4140825069db3fedbbb843290fd5f5bc0a5dbd7eaf81d91bf1b003b"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6a792ce34b999fbe04a7a71a90c74f10c57ae4c51f65461a411faa70e154154e"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f4065145bf69de124accdd17ea5f4dc770da0a6a6e440c53f6e0a8c27b3e635c"},
+ {file = "aiohttp-3.11.16-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa73e8c2656a3653ae6c307b3f4e878a21f87859a9afab228280ddccd7369d71"},
+ {file = "aiohttp-3.11.16-cp312-cp312-win32.whl", hash = "sha256:f244b8e541f414664889e2c87cac11a07b918cb4b540c36f7ada7bfa76571ea2"},
+ {file = "aiohttp-3.11.16-cp312-cp312-win_amd64.whl", hash = "sha256:23a15727fbfccab973343b6d1b7181bfb0b4aa7ae280f36fd2f90f5476805682"},
+ {file = "aiohttp-3.11.16-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a3814760a1a700f3cfd2f977249f1032301d0a12c92aba74605cfa6ce9f78489"},
+ {file = "aiohttp-3.11.16-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9b751a6306f330801665ae69270a8a3993654a85569b3469662efaad6cf5cc50"},
+ {file = "aiohttp-3.11.16-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ad497f38a0d6c329cb621774788583ee12321863cd4bd9feee1effd60f2ad133"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca37057625693d097543bd88076ceebeb248291df9d6ca8481349efc0b05dcd0"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5abcbba9f4b463a45c8ca8b7720891200658f6f46894f79517e6cd11f3405ca"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f420bfe862fb357a6d76f2065447ef6f484bc489292ac91e29bc65d2d7a2c84d"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58ede86453a6cf2d6ce40ef0ca15481677a66950e73b0a788917916f7e35a0bb"},
+ {file = "aiohttp-3.11.16-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fdec0213244c39973674ca2a7f5435bf74369e7d4e104d6c7473c81c9bcc8c4"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:72b1b03fb4655c1960403c131740755ec19c5898c82abd3961c364c2afd59fe7"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:780df0d837276276226a1ff803f8d0fa5f8996c479aeef52eb040179f3156cbd"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ecdb8173e6c7aa09eee342ac62e193e6904923bd232e76b4157ac0bfa670609f"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a6db7458ab89c7d80bc1f4e930cc9df6edee2200127cfa6f6e080cf619eddfbd"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:2540ddc83cc724b13d1838026f6a5ad178510953302a49e6d647f6e1de82bc34"},
+ {file = "aiohttp-3.11.16-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3b4e6db8dc4879015b9955778cfb9881897339c8fab7b3676f8433f849425913"},
+ {file = "aiohttp-3.11.16-cp313-cp313-win32.whl", hash = "sha256:493910ceb2764f792db4dc6e8e4b375dae1b08f72e18e8f10f18b34ca17d0979"},
+ {file = "aiohttp-3.11.16-cp313-cp313-win_amd64.whl", hash = "sha256:42864e70a248f5f6a49fdaf417d9bc62d6e4d8ee9695b24c5916cb4bb666c802"},
+ {file = "aiohttp-3.11.16-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bbcba75fe879ad6fd2e0d6a8d937f34a571f116a0e4db37df8079e738ea95c71"},
+ {file = "aiohttp-3.11.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:87a6e922b2b2401e0b0cf6b976b97f11ec7f136bfed445e16384fbf6fd5e8602"},
+ {file = "aiohttp-3.11.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ccf10f16ab498d20e28bc2b5c1306e9c1512f2840f7b6a67000a517a4b37d5ee"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb3d0cc5cdb926090748ea60172fa8a213cec728bd6c54eae18b96040fcd6227"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d07502cc14ecd64f52b2a74ebbc106893d9a9717120057ea9ea1fd6568a747e7"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:776c8e959a01e5e8321f1dec77964cb6101020a69d5a94cd3d34db6d555e01f7"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0902e887b0e1d50424112f200eb9ae3dfed6c0d0a19fc60f633ae5a57c809656"},
+ {file = "aiohttp-3.11.16-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e87fd812899aa78252866ae03a048e77bd11b80fb4878ce27c23cade239b42b2"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0a950c2eb8ff17361abd8c85987fd6076d9f47d040ebffce67dce4993285e973"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:c10d85e81d0b9ef87970ecbdbfaeec14a361a7fa947118817fcea8e45335fa46"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7951decace76a9271a1ef181b04aa77d3cc309a02a51d73826039003210bdc86"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:14461157d8426bcb40bd94deb0450a6fa16f05129f7da546090cebf8f3123b0f"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9756d9b9d4547e091f99d554fbba0d2a920aab98caa82a8fb3d3d9bee3c9ae85"},
+ {file = "aiohttp-3.11.16-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:87944bd16b7fe6160607f6a17808abd25f17f61ae1e26c47a491b970fb66d8cb"},
+ {file = "aiohttp-3.11.16-cp39-cp39-win32.whl", hash = "sha256:92b7ee222e2b903e0a4b329a9943d432b3767f2d5029dbe4ca59fb75223bbe2e"},
+ {file = "aiohttp-3.11.16-cp39-cp39-win_amd64.whl", hash = "sha256:17ae4664031aadfbcb34fd40ffd90976671fa0c0286e6c4113989f78bebab37a"},
+ {file = "aiohttp-3.11.16.tar.gz", hash = "sha256:16f8a2c9538c14a557b4d309ed4d0a7c60f0253e8ed7b6c9a2859a7582f8b1b8"},
]
[package.dependencies]
+aiohappyeyeballs = ">=2.3.0"
aiosignal = ">=1.1.2"
-async-timeout = ">=4.0.0a3,<5.0"
-asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""}
+async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""}
attrs = ">=17.3.0"
-charset-normalizer = ">=2.0,<4.0"
frozenlist = ">=1.1.1"
multidict = ">=4.5,<7.0"
-typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
-yarl = ">=1.0,<2.0"
+propcache = ">=0.2.0"
+yarl = ">=1.17.0,<2.0"
[package.extras]
-speedups = ["Brotli", "aiodns", "cchardet"]
+speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
[[package]]
name = "aiosignal"
-version = "1.3.1"
+version = "1.3.2"
description = "aiosignal: a list of registered asynchronous callbacks"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
- {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
+ {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"},
+ {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"},
]
[package.dependencies]
@@ -126,167 +133,161 @@ frozenlist = ">=1.1.0"
[[package]]
name = "annotated-types"
-version = "0.5.0"
+version = "0.7.0"
description = "Reusable constraint types to use with typing.Annotated"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"},
- {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"},
+ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
]
-[package.dependencies]
-typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
-
[[package]]
name = "async-timeout"
-version = "4.0.3"
+version = "5.0.1"
description = "Timeout context manager for asyncio programs"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "python_version < \"3.11\""
files = [
- {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
- {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
-]
-
-[package.dependencies]
-typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""}
-
-[[package]]
-name = "asynctest"
-version = "0.13.0"
-description = "Enhance the standard unittest package with features for testing asyncio libraries"
-optional = false
-python-versions = ">=3.5"
-files = [
- {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"},
- {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"},
-]
-
-[[package]]
-name = "atomicwrites"
-version = "1.4.1"
-description = "Atomic file writes."
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-files = [
- {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"},
+ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
+ {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
]
[[package]]
name = "attrs"
-version = "23.1.0"
+version = "25.3.0"
description = "Classes Without Boilerplate"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"},
- {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"},
+ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"},
+ {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"},
]
-[package.dependencies]
-importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
-
[package.extras]
-cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
-dev = ["attrs[docs,tests]", "pre-commit"]
-docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
-tests = ["attrs[tests-no-zope]", "zope-interface"]
-tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"]
+tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
[[package]]
name = "certifi"
-version = "2023.7.22"
+version = "2025.1.31"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
+groups = ["main"]
files = [
- {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"},
- {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"},
+ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"},
+ {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
]
[[package]]
name = "charset-normalizer"
-version = "3.2.0"
+version = "3.4.1"
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
optional = false
-python-versions = ">=3.7.0"
+python-versions = ">=3.7"
+groups = ["main"]
files = [
- {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"},
- {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"},
- {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"},
- {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"},
- {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"},
- {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"},
- {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f"},
+ {file = "charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b"},
+ {file = "charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35"},
+ {file = "charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407"},
+ {file = "charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win32.whl", hash = "sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487"},
+ {file = "charset_normalizer-3.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win32.whl", hash = "sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e"},
+ {file = "charset_normalizer-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win32.whl", hash = "sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5"},
+ {file = "charset_normalizer-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765"},
+ {file = "charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85"},
+ {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"},
]
[[package]]
@@ -295,78 +296,84 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+groups = ["main", "dev"]
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
+markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""}
[[package]]
name = "coverage"
-version = "7.2.7"
+version = "7.8.0"
description = "Code coverage measurement for Python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["dev"]
files = [
- {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"},
- {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"},
- {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"},
- {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"},
- {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"},
- {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"},
- {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"},
- {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"},
- {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"},
- {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"},
- {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"},
- {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"},
- {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"},
- {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"},
- {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"},
- {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"},
- {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"},
- {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"},
- {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"},
- {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"},
- {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"},
- {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"},
- {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"},
- {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"},
- {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"},
- {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"},
- {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"},
- {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"},
- {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"},
- {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"},
- {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"},
- {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"},
- {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"},
- {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"},
- {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"},
- {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"},
- {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"},
- {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"},
- {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"},
- {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"},
- {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"},
- {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"},
- {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"},
- {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"},
- {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"},
- {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"},
- {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"},
- {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"},
- {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"},
- {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"},
- {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"},
- {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"},
- {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"},
- {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"},
- {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"},
- {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"},
- {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"},
- {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"},
- {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"},
- {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"},
+ {file = "coverage-7.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2931f66991175369859b5fd58529cd4b73582461877ecfd859b6549869287ffe"},
+ {file = "coverage-7.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52a523153c568d2c0ef8826f6cc23031dc86cffb8c6aeab92c4ff776e7951b28"},
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c8a5c139aae4c35cbd7cadca1df02ea8cf28a911534fc1b0456acb0b14234f3"},
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a26c0c795c3e0b63ec7da6efded5f0bc856d7c0b24b2ac84b4d1d7bc578d676"},
+ {file = "coverage-7.8.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:821f7bcbaa84318287115d54becb1915eece6918136c6f91045bb84e2f88739d"},
+ {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a321c61477ff8ee705b8a5fed370b5710c56b3a52d17b983d9215861e37b642a"},
+ {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ed2144b8a78f9d94d9515963ed273d620e07846acd5d4b0a642d4849e8d91a0c"},
+ {file = "coverage-7.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:042e7841a26498fff7a37d6fda770d17519982f5b7d8bf5278d140b67b61095f"},
+ {file = "coverage-7.8.0-cp310-cp310-win32.whl", hash = "sha256:f9983d01d7705b2d1f7a95e10bbe4091fabc03a46881a256c2787637b087003f"},
+ {file = "coverage-7.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5a570cd9bd20b85d1a0d7b009aaf6c110b52b5755c17be6962f8ccd65d1dbd23"},
+ {file = "coverage-7.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7ac22a0bb2c7c49f441f7a6d46c9c80d96e56f5a8bc6972529ed43c8b694e27"},
+ {file = "coverage-7.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bf13d564d310c156d1c8e53877baf2993fb3073b2fc9f69790ca6a732eb4bfea"},
+ {file = "coverage-7.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5761c70c017c1b0d21b0815a920ffb94a670c8d5d409d9b38857874c21f70d7"},
+ {file = "coverage-7.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5ff52d790c7e1628241ffbcaeb33e07d14b007b6eb00a19320c7b8a7024c040"},
+ {file = "coverage-7.8.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d39fc4817fd67b3915256af5dda75fd4ee10621a3d484524487e33416c6f3543"},
+ {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b44674870709017e4b4036e3d0d6c17f06a0e6d4436422e0ad29b882c40697d2"},
+ {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8f99eb72bf27cbb167b636eb1726f590c00e1ad375002230607a844d9e9a2318"},
+ {file = "coverage-7.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b571bf5341ba8c6bc02e0baeaf3b061ab993bf372d982ae509807e7f112554e9"},
+ {file = "coverage-7.8.0-cp311-cp311-win32.whl", hash = "sha256:e75a2ad7b647fd8046d58c3132d7eaf31b12d8a53c0e4b21fa9c4d23d6ee6d3c"},
+ {file = "coverage-7.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:3043ba1c88b2139126fc72cb48574b90e2e0546d4c78b5299317f61b7f718b78"},
+ {file = "coverage-7.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bbb5cc845a0292e0c520656d19d7ce40e18d0e19b22cb3e0409135a575bf79fc"},
+ {file = "coverage-7.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4dfd9a93db9e78666d178d4f08a5408aa3f2474ad4d0e0378ed5f2ef71640cb6"},
+ {file = "coverage-7.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f017a61399f13aa6d1039f75cd467be388d157cd81f1a119b9d9a68ba6f2830d"},
+ {file = "coverage-7.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0915742f4c82208ebf47a2b154a5334155ed9ef9fe6190674b8a46c2fb89cb05"},
+ {file = "coverage-7.8.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a40fcf208e021eb14b0fac6bdb045c0e0cab53105f93ba0d03fd934c956143a"},
+ {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a1f406a8e0995d654b2ad87c62caf6befa767885301f3b8f6f73e6f3c31ec3a6"},
+ {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:77af0f6447a582fdc7de5e06fa3757a3ef87769fbb0fdbdeba78c23049140a47"},
+ {file = "coverage-7.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f2d32f95922927186c6dbc8bc60df0d186b6edb828d299ab10898ef3f40052fe"},
+ {file = "coverage-7.8.0-cp312-cp312-win32.whl", hash = "sha256:769773614e676f9d8e8a0980dd7740f09a6ea386d0f383db6821df07d0f08545"},
+ {file = "coverage-7.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e5d2b9be5b0693cf21eb4ce0ec8d211efb43966f6657807f6859aab3814f946b"},
+ {file = "coverage-7.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ac46d0c2dd5820ce93943a501ac5f6548ea81594777ca585bf002aa8854cacd"},
+ {file = "coverage-7.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:771eb7587a0563ca5bb6f622b9ed7f9d07bd08900f7589b4febff05f469bea00"},
+ {file = "coverage-7.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42421e04069fb2cbcbca5a696c4050b84a43b05392679d4068acbe65449b5c64"},
+ {file = "coverage-7.8.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:554fec1199d93ab30adaa751db68acec2b41c5602ac944bb19187cb9a41a8067"},
+ {file = "coverage-7.8.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aaeb00761f985007b38cf463b1d160a14a22c34eb3f6a39d9ad6fc27cb73008"},
+ {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:581a40c7b94921fffd6457ffe532259813fc68eb2bdda60fa8cc343414ce3733"},
+ {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f319bae0321bc838e205bf9e5bc28f0a3165f30c203b610f17ab5552cff90323"},
+ {file = "coverage-7.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04bfec25a8ef1c5f41f5e7e5c842f6b615599ca8ba8391ec33a9290d9d2db3a3"},
+ {file = "coverage-7.8.0-cp313-cp313-win32.whl", hash = "sha256:dd19608788b50eed889e13a5d71d832edc34fc9dfce606f66e8f9f917eef910d"},
+ {file = "coverage-7.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:a9abbccd778d98e9c7e85038e35e91e67f5b520776781d9a1e2ee9d400869487"},
+ {file = "coverage-7.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:18c5ae6d061ad5b3e7eef4363fb27a0576012a7447af48be6c75b88494c6cf25"},
+ {file = "coverage-7.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:95aa6ae391a22bbbce1b77ddac846c98c5473de0372ba5c463480043a07bff42"},
+ {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e013b07ba1c748dacc2a80e69a46286ff145935f260eb8c72df7185bf048f502"},
+ {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d766a4f0e5aa1ba056ec3496243150698dc0481902e2b8559314368717be82b1"},
+ {file = "coverage-7.8.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad80e6b4a0c3cb6f10f29ae4c60e991f424e6b14219d46f1e7d442b938ee68a4"},
+ {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b87eb6fc9e1bb8f98892a2458781348fa37e6925f35bb6ceb9d4afd54ba36c73"},
+ {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d1ba00ae33be84066cfbe7361d4e04dec78445b2b88bdb734d0d1cbab916025a"},
+ {file = "coverage-7.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f3c38e4e5ccbdc9198aecc766cedbb134b2d89bf64533973678dfcf07effd883"},
+ {file = "coverage-7.8.0-cp313-cp313t-win32.whl", hash = "sha256:379fe315e206b14e21db5240f89dc0774bdd3e25c3c58c2c733c99eca96f1ada"},
+ {file = "coverage-7.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2e4b6b87bb0c846a9315e3ab4be2d52fac905100565f4b92f02c445c8799e257"},
+ {file = "coverage-7.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa260de59dfb143af06dcf30c2be0b200bed2a73737a8a59248fcb9fa601ef0f"},
+ {file = "coverage-7.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:96121edfa4c2dfdda409877ea8608dd01de816a4dc4a0523356067b305e4e17a"},
+ {file = "coverage-7.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b8af63b9afa1031c0ef05b217faa598f3069148eeee6bb24b79da9012423b82"},
+ {file = "coverage-7.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89b1f4af0d4afe495cd4787a68e00f30f1d15939f550e869de90a86efa7e0814"},
+ {file = "coverage-7.8.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94ec0be97723ae72d63d3aa41961a0b9a6f5a53ff599813c324548d18e3b9e8c"},
+ {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a1d96e780bdb2d0cbb297325711701f7c0b6f89199a57f2049e90064c29f6bd"},
+ {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f1d8a2a57b47142b10374902777e798784abf400a004b14f1b0b9eaf1e528ba4"},
+ {file = "coverage-7.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cf60dd2696b457b710dd40bf17ad269d5f5457b96442f7f85722bdb16fa6c899"},
+ {file = "coverage-7.8.0-cp39-cp39-win32.whl", hash = "sha256:be945402e03de47ba1872cd5236395e0f4ad635526185a930735f66710e1bd3f"},
+ {file = "coverage-7.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:90e7fbc6216ecaffa5a880cdc9c77b7418c1dcb166166b78dbc630d07f278cc3"},
+ {file = "coverage-7.8.0-pp39.pp310.pp311-none-any.whl", hash = "sha256:b8194fb8e50d556d5849753de991d390c5a1edeeba50f68e3a9253fbd8bf8ccd"},
+ {file = "coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7"},
+ {file = "coverage-7.8.0.tar.gz", hash = "sha256:7a3d62b3b03b4b6fd41a085f3574874cf946cb4604d2b4d3e8dca8cd570ca501"},
]
[package.dependencies]
@@ -376,112 +383,150 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1
toml = ["tomli"]
[[package]]
-name = "filelock"
-version = "3.12.2"
-description = "A platform independent file lock."
+name = "exceptiongroup"
+version = "1.2.2"
+description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
+groups = ["dev"]
+markers = "python_version < \"3.11\""
files = [
- {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"},
- {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"},
+ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
+ {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
]
[package.extras]
-docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
+test = ["pytest (>=6)"]
+
+[[package]]
+name = "filelock"
+version = "3.18.0"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de"},
+ {file = "filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2"},
+]
+
+[package.extras]
+docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"]
+typing = ["typing-extensions (>=4.12.2)"]
[[package]]
name = "frozenlist"
-version = "1.3.3"
+version = "1.5.0"
description = "A list-like structure which implements collections.abc.MutableSequence"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"},
- {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"},
- {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"},
- {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"},
- {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"},
- {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"},
- {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"},
- {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"},
- {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"},
- {file = "frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"},
- {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"},
- {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"},
- {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"},
- {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"},
- {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"},
- {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"},
- {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"},
- {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"},
- {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"},
- {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"},
- {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"},
- {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"},
- {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"},
- {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"},
- {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"},
- {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"},
- {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"},
- {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"},
- {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"},
- {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"},
- {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"},
- {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"},
- {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"},
- {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"},
+ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"},
+ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"},
+ {file = "frozenlist-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15538c0cbf0e4fa11d1e3a71f823524b0c46299aed6e10ebb4c2089abd8c3bec"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e79225373c317ff1e35f210dd5f1344ff31066ba8067c307ab60254cd3a78ad5"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9272fa73ca71266702c4c3e2d4a28553ea03418e591e377a03b8e3659d94fa76"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:498524025a5b8ba81695761d78c8dd7382ac0b052f34e66939c42df860b8ff17"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92b5278ed9d50fe610185ecd23c55d8b307d75ca18e94c0e7de328089ac5dcba"},
+ {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f3c8c1dacd037df16e85227bac13cca58c30da836c6f936ba1df0c05d046d8d"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f2ac49a9bedb996086057b75bf93538240538c6d9b38e57c82d51f75a73409d2"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e66cc454f97053b79c2ab09c17fbe3c825ea6b4de20baf1be28919460dd7877f"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3ba5f9a0dfed20337d3e966dc359784c9f96503674c2faf015f7fe8e96798c"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6321899477db90bdeb9299ac3627a6a53c7399c8cd58d25da094007402b039ab"},
+ {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76e4753701248476e6286f2ef492af900ea67d9706a0155335a40ea21bf3b2f5"},
+ {file = "frozenlist-1.5.0-cp310-cp310-win32.whl", hash = "sha256:977701c081c0241d0955c9586ffdd9ce44f7a7795df39b9151cd9a6fd0ce4cfb"},
+ {file = "frozenlist-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:189f03b53e64144f90990d29a27ec4f7997d91ed3d01b51fa39d2dbe77540fd4"},
+ {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fd74520371c3c4175142d02a976aee0b4cb4a7cc912a60586ffd8d5929979b30"},
+ {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2f3f7a0fbc219fb4455264cae4d9f01ad41ae6ee8524500f381de64ffaa077d5"},
+ {file = "frozenlist-1.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f47c9c9028f55a04ac254346e92977bf0f166c483c74b4232bee19a6697e4778"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0996c66760924da6e88922756d99b47512a71cfd45215f3570bf1e0b694c206a"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2fe128eb4edeabe11896cb6af88fca5346059f6c8d807e3b910069f39157869"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a8ea951bbb6cacd492e3948b8da8c502a3f814f5d20935aae74b5df2b19cf3d"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de537c11e4aa01d37db0d403b57bd6f0546e71a82347a97c6a9f0dcc532b3a45"},
+ {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c2623347b933fcb9095841f1cc5d4ff0b278addd743e0e966cb3d460278840d"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cee6798eaf8b1416ef6909b06f7dc04b60755206bddc599f52232606e18179d3"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f5f9da7f5dbc00a604fe74aa02ae7c98bcede8a3b8b9666f9f86fc13993bc71a"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:90646abbc7a5d5c7c19461d2e3eeb76eb0b204919e6ece342feb6032c9325ae9"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:bdac3c7d9b705d253b2ce370fde941836a5f8b3c5c2b8fd70940a3ea3af7f4f2"},
+ {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03d33c2ddbc1816237a67f66336616416e2bbb6beb306e5f890f2eb22b959cdf"},
+ {file = "frozenlist-1.5.0-cp311-cp311-win32.whl", hash = "sha256:237f6b23ee0f44066219dae14c70ae38a63f0440ce6750f868ee08775073f942"},
+ {file = "frozenlist-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:0cc974cc93d32c42e7b0f6cf242a6bd941c57c61b618e78b6c0a96cb72788c1d"},
+ {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:31115ba75889723431aa9a4e77d5f398f5cf976eea3bdf61749731f62d4a4a21"},
+ {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7437601c4d89d070eac8323f121fcf25f88674627505334654fd027b091db09d"},
+ {file = "frozenlist-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7948140d9f8ece1745be806f2bfdf390127cf1a763b925c4a805c603df5e697e"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feeb64bc9bcc6b45c6311c9e9b99406660a9c05ca8a5b30d14a78555088b0b3a"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:683173d371daad49cffb8309779e886e59c2f369430ad28fe715f66d08d4ab1a"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d57d8f702221405a9d9b40f9da8ac2e4a1a8b5285aac6100f3393675f0a85ee"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c72000fbcc35b129cb09956836c7d7abf78ab5416595e4857d1cae8d6251a6"},
+ {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000a77d6034fbad9b6bb880f7ec073027908f1b40254b5d6f26210d2dab1240e"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d7f5a50342475962eb18b740f3beecc685a15b52c91f7d975257e13e029eca9"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:87f724d055eb4785d9be84e9ebf0f24e392ddfad00b3fe036e43f489fafc9039"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6e9080bb2fb195a046e5177f10d9d82b8a204c0736a97a153c2466127de87784"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b93d7aaa36c966fa42efcaf716e6b3900438632a626fb09c049f6a2f09fc631"},
+ {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:52ef692a4bc60a6dd57f507429636c2af8b6046db8b31b18dac02cbc8f507f7f"},
+ {file = "frozenlist-1.5.0-cp312-cp312-win32.whl", hash = "sha256:29d94c256679247b33a3dc96cce0f93cbc69c23bf75ff715919332fdbb6a32b8"},
+ {file = "frozenlist-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:8969190d709e7c48ea386db202d708eb94bdb29207a1f269bab1196ce0dcca1f"},
+ {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a1a048f9215c90973402e26c01d1cff8a209e1f1b53f72b95c13db61b00f953"},
+ {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dd47a5181ce5fcb463b5d9e17ecfdb02b678cca31280639255ce9d0e5aa67af0"},
+ {file = "frozenlist-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1431d60b36d15cda188ea222033eec8e0eab488f39a272461f2e6d9e1a8e63c2"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6482a5851f5d72767fbd0e507e80737f9c8646ae7fd303def99bfe813f76cf7f"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44c49271a937625619e862baacbd037a7ef86dd1ee215afc298a417ff3270608"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12f78f98c2f1c2429d42e6a485f433722b0061d5c0b0139efa64f396efb5886b"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce3aa154c452d2467487765e3adc730a8c153af77ad84096bc19ce19a2400840"},
+ {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b7dc0c4338e6b8b091e8faf0db3168a37101943e687f373dce00959583f7439"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45e0896250900b5aa25180f9aec243e84e92ac84bd4a74d9ad4138ef3f5c97de"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:561eb1c9579d495fddb6da8959fd2a1fca2c6d060d4113f5844b433fc02f2641"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:df6e2f325bfee1f49f81aaac97d2aa757c7646534a06f8f577ce184afe2f0a9e"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:140228863501b44b809fb39ec56b5d4071f4d0aa6d216c19cbb08b8c5a7eadb9"},
+ {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7707a25d6a77f5d27ea7dc7d1fc608aa0a478193823f88511ef5e6b8a48f9d03"},
+ {file = "frozenlist-1.5.0-cp313-cp313-win32.whl", hash = "sha256:31a9ac2b38ab9b5a8933b693db4939764ad3f299fcaa931a3e605bc3460e693c"},
+ {file = "frozenlist-1.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:11aabdd62b8b9c4b84081a3c246506d1cddd2dd93ff0ad53ede5defec7886b28"},
+ {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:dd94994fc91a6177bfaafd7d9fd951bc8689b0a98168aa26b5f543868548d3ca"},
+ {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2d0da8bbec082bf6bf18345b180958775363588678f64998c2b7609e34719b10"},
+ {file = "frozenlist-1.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:73f2e31ea8dd7df61a359b731716018c2be196e5bb3b74ddba107f694fbd7604"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:828afae9f17e6de596825cf4228ff28fbdf6065974e5ac1410cecc22f699d2b3"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1577515d35ed5649d52ab4319db757bb881ce3b2b796d7283e6634d99ace307"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2150cc6305a2c2ab33299453e2968611dacb970d2283a14955923062c8d00b10"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a72b7a6e3cd2725eff67cd64c8f13335ee18fc3c7befc05aed043d24c7b9ccb9"},
+ {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c16d2fa63e0800723139137d667e1056bee1a1cf7965153d2d104b62855e9b99"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:17dcc32fc7bda7ce5875435003220a457bcfa34ab7924a49a1c19f55b6ee185c"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:97160e245ea33d8609cd2b8fd997c850b56db147a304a262abc2b3be021a9171"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f1e6540b7fa044eee0bb5111ada694cf3dc15f2b0347ca125ee9ca984d5e9e6e"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:91d6c171862df0a6c61479d9724f22efb6109111017c87567cfeb7b5d1449fdf"},
+ {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c1fac3e2ace2eb1052e9f7c7db480818371134410e1f5c55d65e8f3ac6d1407e"},
+ {file = "frozenlist-1.5.0-cp38-cp38-win32.whl", hash = "sha256:b97f7b575ab4a8af9b7bc1d2ef7f29d3afee2226bd03ca3875c16451ad5a7723"},
+ {file = "frozenlist-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:374ca2dabdccad8e2a76d40b1d037f5bd16824933bf7bcea3e59c891fd4a0923"},
+ {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9bbcdfaf4af7ce002694a4e10a0159d5a8d20056a12b05b45cea944a4953f972"},
+ {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1893f948bf6681733aaccf36c5232c231e3b5166d607c5fa77773611df6dc336"},
+ {file = "frozenlist-1.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b5e23253bb709ef57a8e95e6ae48daa9ac5f265637529e4ce6b003a37b2621f"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f253985bb515ecd89629db13cb58d702035ecd8cfbca7d7a7e29a0e6d39af5f"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04a5c6babd5e8fb7d3c871dc8b321166b80e41b637c31a995ed844a6139942b6"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9fe0f1c29ba24ba6ff6abf688cb0b7cf1efab6b6aa6adc55441773c252f7411"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d72559fa19babe2ccd920273e767c96a49b9d3d38badd7c91a0fdeda8ea08"},
+ {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15b731db116ab3aedec558573c1a5eec78822b32292fe4f2f0345b7f697745c2"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:366d8f93e3edfe5a918c874702f78faac300209a4d5bf38352b2c1bdc07a766d"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1b96af8c582b94d381a1c1f51ffaedeb77c821c690ea5f01da3d70a487dd0a9b"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c03eff4a41bd4e38415cbed054bbaff4a075b093e2394b6915dca34a40d1e38b"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:50cf5e7ee9b98f22bdecbabf3800ae78ddcc26e4a435515fc72d97903e8488e0"},
+ {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e76bfbc72353269c44e0bc2cfe171900fbf7f722ad74c9a7b638052afe6a00c"},
+ {file = "frozenlist-1.5.0-cp39-cp39-win32.whl", hash = "sha256:666534d15ba8f0fda3f53969117383d5dc021266b3c1a42c9ec4855e4b58b9d3"},
+ {file = "frozenlist-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:5c28f4b5dbef8a0d8aad0d4de24d1e9e981728628afaf4ea0792f5d0939372f0"},
+ {file = "frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3"},
+ {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"},
]
[[package]]
name = "fsspec"
-version = "2023.1.0"
+version = "2025.3.2"
description = "File-system specification"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "fsspec-2023.1.0-py3-none-any.whl", hash = "sha256:b833e2e541e9e8cde0ab549414187871243177feb3d344f9d27b25a93f5d8139"},
- {file = "fsspec-2023.1.0.tar.gz", hash = "sha256:fbae7f20ff801eb5f7d0bedf81f25c787c0dfac5e982d98fa3884a9cde2b5411"},
+ {file = "fsspec-2025.3.2-py3-none-any.whl", hash = "sha256:2daf8dc3d1dfa65b6aa37748d112773a7a08416f6c70d96b264c96476ecaf711"},
+ {file = "fsspec-2025.3.2.tar.gz", hash = "sha256:e52c77ef398680bbd6a98c0e628fbc469491282981209907bbc8aea76a04fdc6"},
]
[package.extras]
@@ -489,8 +534,10 @@ abfs = ["adlfs"]
adl = ["adlfs"]
arrow = ["pyarrow (>=1)"]
dask = ["dask", "distributed"]
+dev = ["pre-commit", "ruff"]
+doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
dropbox = ["dropbox", "dropboxdrivefs", "requests"]
-entrypoints = ["importlib-metadata"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
fuse = ["fusepy"]
gcs = ["gcsfs"]
git = ["pygit2"]
@@ -498,30 +545,33 @@ github = ["requests"]
gs = ["gcsfs"]
gui = ["panel"]
hdfs = ["pyarrow (>=1)"]
-http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
libarchive = ["libarchive-c"]
oci = ["ocifs"]
s3 = ["s3fs"]
sftp = ["paramiko"]
smb = ["smbprotocol"]
ssh = ["paramiko"]
+test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
tqdm = ["tqdm"]
[[package]]
name = "huggingface-hub"
-version = "0.16.4"
+version = "0.30.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
-python-versions = ">=3.7.0"
+python-versions = ">=3.8.0"
+groups = ["main"]
files = [
- {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"},
- {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"},
+ {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
+ {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
]
[package.dependencies]
filelock = "*"
-fsspec = "*"
-importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+fsspec = ">=2023.5.0"
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
@@ -529,314 +579,429 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
-all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
-dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
-inference = ["aiohttp", "pydantic"]
-quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"]
+hf-transfer = ["hf-transfer (>=0.1.4)"]
+hf-xet = ["hf-xet (>=0.1.4)"]
+inference = ["aiohttp"]
+quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
-testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
-torch = ["torch"]
-typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors[torch]", "torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]]
name = "idna"
-version = "3.4"
+version = "3.10"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = false
-python-versions = ">=3.5"
+python-versions = ">=3.6"
+groups = ["main"]
files = [
- {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
- {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
+ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
+ {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
]
-[[package]]
-name = "importlib-metadata"
-version = "6.7.0"
-description = "Read metadata from Python packages"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"},
- {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"},
-]
-
-[package.dependencies]
-typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""}
-zipp = ">=0.5"
-
[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
[[package]]
name = "iniconfig"
-version = "2.0.0"
+version = "2.1.0"
description = "brain-dead simple config-ini parsing"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["dev"]
files = [
- {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
- {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
+ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
]
[[package]]
name = "multidict"
-version = "6.0.4"
+version = "6.4.3"
description = "multidict implementation"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"},
- {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"},
- {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"},
- {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"},
- {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"},
- {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"},
- {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"},
- {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"},
- {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"},
- {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"},
- {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"},
- {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"},
- {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"},
- {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"},
- {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"},
- {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"},
- {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"},
- {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"},
- {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"},
- {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"},
- {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"},
- {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"},
- {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"},
- {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"},
- {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"},
- {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"},
- {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"},
- {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
+ {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:32a998bd8a64ca48616eac5a8c1cc4fa38fb244a3facf2eeb14abe186e0f6cc5"},
+ {file = "multidict-6.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a54ec568f1fc7f3c313c2f3b16e5db346bf3660e1309746e7fccbbfded856188"},
+ {file = "multidict-6.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a7be07e5df178430621c716a63151165684d3e9958f2bbfcb644246162007ab7"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b128dbf1c939674a50dd0b28f12c244d90e5015e751a4f339a96c54f7275e291"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b9cb19dfd83d35b6ff24a4022376ea6e45a2beba8ef3f0836b8a4b288b6ad685"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3cf62f8e447ea2c1395afa289b332e49e13d07435369b6f4e41f887db65b40bf"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:909f7d43ff8f13d1adccb6a397094adc369d4da794407f8dd592c51cf0eae4b1"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bb8f8302fbc7122033df959e25777b0b7659b1fd6bcb9cb6bed76b5de67afef"},
+ {file = "multidict-6.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:224b79471b4f21169ea25ebc37ed6f058040c578e50ade532e2066562597b8a9"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a7bd27f7ab3204f16967a6f899b3e8e9eb3362c0ab91f2ee659e0345445e0078"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:99592bd3162e9c664671fd14e578a33bfdba487ea64bcb41d281286d3c870ad7"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a62d78a1c9072949018cdb05d3c533924ef8ac9bcb06cbf96f6d14772c5cd451"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ccdde001578347e877ca4f629450973c510e88e8865d5aefbcb89b852ccc666"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:eccb67b0e78aa2e38a04c5ecc13bab325a43e5159a181a9d1a6723db913cbb3c"},
+ {file = "multidict-6.4.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8b6fcf6054fc4114a27aa865f8840ef3d675f9316e81868e0ad5866184a6cba5"},
+ {file = "multidict-6.4.3-cp310-cp310-win32.whl", hash = "sha256:f92c7f62d59373cd93bc9969d2da9b4b21f78283b1379ba012f7ee8127b3152e"},
+ {file = "multidict-6.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:b57e28dbc031d13916b946719f213c494a517b442d7b48b29443e79610acd887"},
+ {file = "multidict-6.4.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f6f19170197cc29baccd33ccc5b5d6a331058796485857cf34f7635aa25fb0cd"},
+ {file = "multidict-6.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2882bf27037eb687e49591690e5d491e677272964f9ec7bc2abbe09108bdfb8"},
+ {file = "multidict-6.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fbf226ac85f7d6b6b9ba77db4ec0704fde88463dc17717aec78ec3c8546c70ad"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e329114f82ad4b9dd291bef614ea8971ec119ecd0f54795109976de75c9a852"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1f4e0334d7a555c63f5c8952c57ab6f1c7b4f8c7f3442df689fc9f03df315c08"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:740915eb776617b57142ce0bb13b7596933496e2f798d3d15a20614adf30d229"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255dac25134d2b141c944b59a0d2f7211ca12a6d4779f7586a98b4b03ea80508"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4e8535bd4d741039b5aad4285ecd9b902ef9e224711f0b6afda6e38d7ac02c7"},
+ {file = "multidict-6.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c433a33be000dd968f5750722eaa0991037be0be4a9d453eba121774985bc8"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4eb33b0bdc50acd538f45041f5f19945a1f32b909b76d7b117c0c25d8063df56"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:75482f43465edefd8a5d72724887ccdcd0c83778ded8f0cb1e0594bf71736cc0"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ce5b3082e86aee80b3925ab4928198450d8e5b6466e11501fe03ad2191c6d777"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e413152e3212c4d39f82cf83c6f91be44bec9ddea950ce17af87fbf4e32ca6b2"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8aac2eeff69b71f229a405c0a4b61b54bade8e10163bc7b44fcd257949620618"},
+ {file = "multidict-6.4.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ab583ac203af1d09034be41458feeab7863c0635c650a16f15771e1386abf2d7"},
+ {file = "multidict-6.4.3-cp311-cp311-win32.whl", hash = "sha256:1b2019317726f41e81154df636a897de1bfe9228c3724a433894e44cd2512378"},
+ {file = "multidict-6.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:43173924fa93c7486402217fab99b60baf78d33806af299c56133a3755f69589"},
+ {file = "multidict-6.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f1c2f58f08b36f8475f3ec6f5aeb95270921d418bf18f90dffd6be5c7b0e676"},
+ {file = "multidict-6.4.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:26ae9ad364fc61b936fb7bf4c9d8bd53f3a5b4417142cd0be5c509d6f767e2f1"},
+ {file = "multidict-6.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:659318c6c8a85f6ecfc06b4e57529e5a78dfdd697260cc81f683492ad7e9435a"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e1eb72c741fd24d5a28242ce72bb61bc91f8451877131fa3fe930edb195f7054"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3cd06d88cb7398252284ee75c8db8e680aa0d321451132d0dba12bc995f0adcc"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4543d8dc6470a82fde92b035a92529317191ce993533c3c0c68f56811164ed07"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:30a3ebdc068c27e9d6081fca0e2c33fdf132ecea703a72ea216b81a66860adde"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b038f10e23f277153f86f95c777ba1958bcd5993194fda26a1d06fae98b2f00c"},
+ {file = "multidict-6.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c605a2b2dc14282b580454b9b5d14ebe0668381a3a26d0ac39daa0ca115eb2ae"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8bd2b875f4ca2bb527fe23e318ddd509b7df163407b0fb717df229041c6df5d3"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c2e98c840c9c8e65c0e04b40c6c5066c8632678cd50c8721fdbcd2e09f21a507"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:66eb80dd0ab36dbd559635e62fba3083a48a252633164857a1d1684f14326427"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c23831bdee0a2a3cf21be057b5e5326292f60472fb6c6f86392bbf0de70ba731"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1535cec6443bfd80d028052e9d17ba6ff8a5a3534c51d285ba56c18af97e9713"},
+ {file = "multidict-6.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3b73e7227681f85d19dec46e5b881827cd354aabe46049e1a61d2f9aaa4e285a"},
+ {file = "multidict-6.4.3-cp312-cp312-win32.whl", hash = "sha256:8eac0c49df91b88bf91f818e0a24c1c46f3622978e2c27035bfdca98e0e18124"},
+ {file = "multidict-6.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:11990b5c757d956cd1db7cb140be50a63216af32cd6506329c2c59d732d802db"},
+ {file = "multidict-6.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a76534263d03ae0cfa721fea40fd2b5b9d17a6f85e98025931d41dc49504474"},
+ {file = "multidict-6.4.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:805031c2f599eee62ac579843555ed1ce389ae00c7e9f74c2a1b45e0564a88dd"},
+ {file = "multidict-6.4.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c56c179839d5dcf51d565132185409d1d5dd8e614ba501eb79023a6cab25576b"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c64f4ddb3886dd8ab71b68a7431ad4aa01a8fa5be5b11543b29674f29ca0ba3"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3002a856367c0b41cad6784f5b8d3ab008eda194ed7864aaa58f65312e2abcac"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d75e621e7d887d539d6e1d789f0c64271c250276c333480a9e1de089611f790"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:995015cf4a3c0d72cbf453b10a999b92c5629eaf3a0c3e1efb4b5c1f602253bb"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b0fabae7939d09d7d16a711468c385272fa1b9b7fb0d37e51143585d8e72e0"},
+ {file = "multidict-6.4.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:61ed4d82f8a1e67eb9eb04f8587970d78fe7cddb4e4d6230b77eda23d27938f9"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:062428944a8dc69df9fdc5d5fc6279421e5f9c75a9ee3f586f274ba7b05ab3c8"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:b90e27b4674e6c405ad6c64e515a505c6d113b832df52fdacb6b1ffd1fa9a1d1"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7d50d4abf6729921e9613d98344b74241572b751c6b37feed75fb0c37bd5a817"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:43fe10524fb0a0514be3954be53258e61d87341008ce4914f8e8b92bee6f875d"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:236966ca6c472ea4e2d3f02f6673ebfd36ba3f23159c323f5a496869bc8e47c9"},
+ {file = "multidict-6.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:422a5ec315018e606473ba1f5431e064cf8b2a7468019233dcf8082fabad64c8"},
+ {file = "multidict-6.4.3-cp313-cp313-win32.whl", hash = "sha256:f901a5aace8e8c25d78960dcc24c870c8d356660d3b49b93a78bf38eb682aac3"},
+ {file = "multidict-6.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:1c152c49e42277bc9a2f7b78bd5fa10b13e88d1b0328221e7aef89d5c60a99a5"},
+ {file = "multidict-6.4.3-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:be8751869e28b9c0d368d94f5afcb4234db66fe8496144547b4b6d6a0645cfc6"},
+ {file = "multidict-6.4.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0d4b31f8a68dccbcd2c0ea04f0e014f1defc6b78f0eb8b35f2265e8716a6df0c"},
+ {file = "multidict-6.4.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:032efeab3049e37eef2ff91271884303becc9e54d740b492a93b7e7266e23756"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9e78006af1a7c8a8007e4f56629d7252668344442f66982368ac06522445e375"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:daeac9dd30cda8703c417e4fddccd7c4dc0c73421a0b54a7da2713be125846be"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f6f90700881438953eae443a9c6f8a509808bc3b185246992c4233ccee37fea"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f84627997008390dd15762128dcf73c3365f4ec0106739cde6c20a07ed198ec8"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3307b48cd156153b117c0ea54890a3bdbf858a5b296ddd40dc3852e5f16e9b02"},
+ {file = "multidict-6.4.3-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ead46b0fa1dcf5af503a46e9f1c2e80b5d95c6011526352fa5f42ea201526124"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:1748cb2743bedc339d63eb1bca314061568793acd603a6e37b09a326334c9f44"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:acc9fa606f76fc111b4569348cc23a771cb52c61516dcc6bcef46d612edb483b"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:31469d5832b5885adeb70982e531ce86f8c992334edd2f2254a10fa3182ac504"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:ba46b51b6e51b4ef7bfb84b82f5db0dc5e300fb222a8a13b8cd4111898a869cf"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:389cfefb599edf3fcfd5f64c0410da686f90f5f5e2c4d84e14f6797a5a337af4"},
+ {file = "multidict-6.4.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:64bc2bbc5fba7b9db5c2c8d750824f41c6994e3882e6d73c903c2afa78d091e4"},
+ {file = "multidict-6.4.3-cp313-cp313t-win32.whl", hash = "sha256:0ecdc12ea44bab2807d6b4a7e5eef25109ab1c82a8240d86d3c1fc9f3b72efd5"},
+ {file = "multidict-6.4.3-cp313-cp313t-win_amd64.whl", hash = "sha256:7146a8742ea71b5d7d955bffcef58a9e6e04efba704b52a460134fefd10a8208"},
+ {file = "multidict-6.4.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5427a2679e95a642b7f8b0f761e660c845c8e6fe3141cddd6b62005bd133fc21"},
+ {file = "multidict-6.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24a8caa26521b9ad09732972927d7b45b66453e6ebd91a3c6a46d811eeb7349b"},
+ {file = "multidict-6.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6b5a272bc7c36a2cd1b56ddc6bff02e9ce499f9f14ee4a45c45434ef083f2459"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edf74dc5e212b8c75165b435c43eb0d5e81b6b300a938a4eb82827119115e840"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:9f35de41aec4b323c71f54b0ca461ebf694fb48bec62f65221f52e0017955b39"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae93e0ff43b6f6892999af64097b18561691ffd835e21a8348a441e256592e1f"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e3929269e9d7eff905d6971d8b8c85e7dbc72c18fb99c8eae6fe0a152f2e343"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb6214fe1750adc2a1b801a199d64b5a67671bf76ebf24c730b157846d0e90d2"},
+ {file = "multidict-6.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d79cf5c0c6284e90f72123f4a3e4add52d6c6ebb4a9054e88df15b8d08444c6"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2427370f4a255262928cd14533a70d9738dfacadb7563bc3b7f704cc2360fc4e"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:fbd8d737867912b6c5f99f56782b8cb81f978a97b4437a1c476de90a3e41c9a1"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0ee1bf613c448997f73fc4efb4ecebebb1c02268028dd4f11f011f02300cf1e8"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:578568c4ba5f2b8abd956baf8b23790dbfdc953e87d5b110bce343b4a54fc9e7"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:a059ad6b80de5b84b9fa02a39400319e62edd39d210b4e4f8c4f1243bdac4752"},
+ {file = "multidict-6.4.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:dd53893675b729a965088aaadd6a1f326a72b83742b056c1065bdd2e2a42b4df"},
+ {file = "multidict-6.4.3-cp39-cp39-win32.whl", hash = "sha256:abcfed2c4c139f25c2355e180bcc077a7cae91eefbb8b3927bb3f836c9586f1f"},
+ {file = "multidict-6.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:b1b389ae17296dd739015d5ddb222ee99fd66adeae910de21ac950e00979d897"},
+ {file = "multidict-6.4.3-py3-none-any.whl", hash = "sha256:59fe01ee8e2a1e8ceb3f6dbb216b09c8d9f4ef1c22c4fc825d045a147fa2ebc9"},
+ {file = "multidict-6.4.3.tar.gz", hash = "sha256:3ada0b058c9f213c5f95ba301f922d402ac234f1111a7d8fd70f1b99f3c281ec"},
]
+[package.dependencies]
+typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
+
[[package]]
name = "packaging"
-version = "23.1"
+version = "24.2"
description = "Core utilities for Python packages"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main", "dev"]
files = [
- {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"},
- {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
+ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
+ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
]
[[package]]
name = "pluggy"
-version = "1.2.0"
+version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["dev"]
files = [
- {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"},
- {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"},
+ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
+ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
]
-[package.dependencies]
-importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
-
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
-name = "py"
-version = "1.11.0"
-description = "library with cross-python path, ini-parsing, io, code, log facilities"
+name = "propcache"
+version = "0.3.1"
+description = "Accelerated property cache"
optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
- {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+ {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f27785888d2fdd918bc36de8b8739f2d6c791399552333721b58193f68ea3e98"},
+ {file = "propcache-0.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4e89cde74154c7b5957f87a355bb9c8ec929c167b59c83d90654ea36aeb6180"},
+ {file = "propcache-0.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:730178f476ef03d3d4d255f0c9fa186cb1d13fd33ffe89d39f2cda4da90ceb71"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:967a8eec513dbe08330f10137eacb427b2ca52118769e82ebcfcab0fba92a649"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b9145c35cc87313b5fd480144f8078716007656093d23059e8993d3a8fa730f"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e64e948ab41411958670f1093c0a57acfdc3bee5cf5b935671bbd5313bcf229"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:319fa8765bfd6a265e5fa661547556da381e53274bc05094fc9ea50da51bfd46"},
+ {file = "propcache-0.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66d8ccbc902ad548312b96ed8d5d266d0d2c6d006fd0f66323e9d8f2dd49be7"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2d219b0dbabe75e15e581fc1ae796109b07c8ba7d25b9ae8d650da582bed01b0"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:cd6a55f65241c551eb53f8cf4d2f4af33512c39da5d9777694e9d9c60872f519"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9979643ffc69b799d50d3a7b72b5164a2e97e117009d7af6dfdd2ab906cb72cd"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4cf9e93a81979f1424f1a3d155213dc928f1069d697e4353edb8a5eba67c6259"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2fce1df66915909ff6c824bbb5eb403d2d15f98f1518e583074671a30fe0c21e"},
+ {file = "propcache-0.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4d0dfdd9a2ebc77b869a0b04423591ea8823f791293b527dc1bb896c1d6f1136"},
+ {file = "propcache-0.3.1-cp310-cp310-win32.whl", hash = "sha256:1f6cc0ad7b4560e5637eb2c994e97b4fa41ba8226069c9277eb5ea7101845b42"},
+ {file = "propcache-0.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:47ef24aa6511e388e9894ec16f0fbf3313a53ee68402bc428744a367ec55b833"},
+ {file = "propcache-0.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7f30241577d2fef2602113b70ef7231bf4c69a97e04693bde08ddab913ba0ce5"},
+ {file = "propcache-0.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43593c6772aa12abc3af7784bff4a41ffa921608dd38b77cf1dfd7f5c4e71371"},
+ {file = "propcache-0.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a75801768bbe65499495660b777e018cbe90c7980f07f8aa57d6be79ea6f71da"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6f1324db48f001c2ca26a25fa25af60711e09b9aaf4b28488602776f4f9a744"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cdb0f3e1eb6dfc9965d19734d8f9c481b294b5274337a8cb5cb01b462dcb7e0"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1eb34d90aac9bfbced9a58b266f8946cb5935869ff01b164573a7634d39fbcb5"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f35c7070eeec2cdaac6fd3fe245226ed2a6292d3ee8c938e5bb645b434c5f256"},
+ {file = "propcache-0.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b23c11c2c9e6d4e7300c92e022046ad09b91fd00e36e83c44483df4afa990073"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3e19ea4ea0bf46179f8a3652ac1426e6dcbaf577ce4b4f65be581e237340420d"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bd39c92e4c8f6cbf5f08257d6360123af72af9f4da75a690bef50da77362d25f"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0313e8b923b3814d1c4a524c93dfecea5f39fa95601f6a9b1ac96cd66f89ea0"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e861ad82892408487be144906a368ddbe2dc6297074ade2d892341b35c59844a"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:61014615c1274df8da5991a1e5da85a3ccb00c2d4701ac6f3383afd3ca47ab0a"},
+ {file = "propcache-0.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:71ebe3fe42656a2328ab08933d420df5f3ab121772eef78f2dc63624157f0ed9"},
+ {file = "propcache-0.3.1-cp311-cp311-win32.whl", hash = "sha256:58aa11f4ca8b60113d4b8e32d37e7e78bd8af4d1a5b5cb4979ed856a45e62005"},
+ {file = "propcache-0.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:9532ea0b26a401264b1365146c440a6d78269ed41f83f23818d4b79497aeabe7"},
+ {file = "propcache-0.3.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f78eb8422acc93d7b69964012ad7048764bb45a54ba7a39bb9e146c72ea29723"},
+ {file = "propcache-0.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:89498dd49c2f9a026ee057965cdf8192e5ae070ce7d7a7bd4b66a8e257d0c976"},
+ {file = "propcache-0.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09400e98545c998d57d10035ff623266927cb784d13dd2b31fd33b8a5316b85b"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa8efd8c5adc5a2c9d3b952815ff8f7710cefdcaf5f2c36d26aff51aeca2f12f"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2fe5c910f6007e716a06d269608d307b4f36e7babee5f36533722660e8c4a70"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a0ab8cf8cdd2194f8ff979a43ab43049b1df0b37aa64ab7eca04ac14429baeb7"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:563f9d8c03ad645597b8d010ef4e9eab359faeb11a0a2ac9f7b4bc8c28ebef25"},
+ {file = "propcache-0.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb6e0faf8cb6b4beea5d6ed7b5a578254c6d7df54c36ccd3d8b3eb00d6770277"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1c5c7ab7f2bb3f573d1cb921993006ba2d39e8621019dffb1c5bc94cdbae81e8"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:050b571b2e96ec942898f8eb46ea4bfbb19bd5502424747e83badc2d4a99a44e"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e1c4d24b804b3a87e9350f79e2371a705a188d292fd310e663483af6ee6718ee"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:e4fe2a6d5ce975c117a6bb1e8ccda772d1e7029c1cca1acd209f91d30fa72815"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:feccd282de1f6322f56f6845bf1207a537227812f0a9bf5571df52bb418d79d5"},
+ {file = "propcache-0.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ec314cde7314d2dd0510c6787326bbffcbdc317ecee6b7401ce218b3099075a7"},
+ {file = "propcache-0.3.1-cp312-cp312-win32.whl", hash = "sha256:7d2d5a0028d920738372630870e7d9644ce437142197f8c827194fca404bf03b"},
+ {file = "propcache-0.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:88c423efef9d7a59dae0614eaed718449c09a5ac79a5f224a8b9664d603f04a3"},
+ {file = "propcache-0.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f1528ec4374617a7a753f90f20e2f551121bb558fcb35926f99e3c42367164b8"},
+ {file = "propcache-0.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dc1915ec523b3b494933b5424980831b636fe483d7d543f7afb7b3bf00f0c10f"},
+ {file = "propcache-0.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a110205022d077da24e60b3df8bcee73971be9575dec5573dd17ae5d81751111"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d249609e547c04d190e820d0d4c8ca03ed4582bcf8e4e160a6969ddfb57b62e5"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ced33d827625d0a589e831126ccb4f5c29dfdf6766cac441d23995a65825dcb"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4114c4ada8f3181af20808bedb250da6bae56660e4b8dfd9cd95d4549c0962f7"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:975af16f406ce48f1333ec5e912fe11064605d5c5b3f6746969077cc3adeb120"},
+ {file = "propcache-0.3.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a34aa3a1abc50740be6ac0ab9d594e274f59960d3ad253cd318af76b996dd654"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9cec3239c85ed15bfaded997773fdad9fb5662b0a7cbc854a43f291eb183179e"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:05543250deac8e61084234d5fc54f8ebd254e8f2b39a16b1dce48904f45b744b"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5cb5918253912e088edbf023788de539219718d3b10aef334476b62d2b53de53"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f3bbecd2f34d0e6d3c543fdb3b15d6b60dd69970c2b4c822379e5ec8f6f621d5"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aca63103895c7d960a5b9b044a83f544b233c95e0dcff114389d64d762017af7"},
+ {file = "propcache-0.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a0a9898fdb99bf11786265468571e628ba60af80dc3f6eb89a3545540c6b0ef"},
+ {file = "propcache-0.3.1-cp313-cp313-win32.whl", hash = "sha256:3a02a28095b5e63128bcae98eb59025924f121f048a62393db682f049bf4ac24"},
+ {file = "propcache-0.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:813fbb8b6aea2fc9659815e585e548fe706d6f663fa73dff59a1677d4595a037"},
+ {file = "propcache-0.3.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a444192f20f5ce8a5e52761a031b90f5ea6288b1eef42ad4c7e64fef33540b8f"},
+ {file = "propcache-0.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0fbe94666e62ebe36cd652f5fc012abfbc2342de99b523f8267a678e4dfdee3c"},
+ {file = "propcache-0.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f011f104db880f4e2166bcdcf7f58250f7a465bc6b068dc84c824a3d4a5c94dc"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e584b6d388aeb0001d6d5c2bd86b26304adde6d9bb9bfa9c4889805021b96de"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a17583515a04358b034e241f952f1715243482fc2c2945fd99a1b03a0bd77d6"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5aed8d8308215089c0734a2af4f2e95eeb360660184ad3912686c181e500b2e7"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8e309ff9a0503ef70dc9a0ebd3e69cf7b3894c9ae2ae81fc10943c37762458"},
+ {file = "propcache-0.3.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b655032b202028a582d27aeedc2e813299f82cb232f969f87a4fde491a233f11"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9f64d91b751df77931336b5ff7bafbe8845c5770b06630e27acd5dbb71e1931c"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:19a06db789a4bd896ee91ebc50d059e23b3639c25d58eb35be3ca1cbe967c3bf"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:bef100c88d8692864651b5f98e871fb090bd65c8a41a1cb0ff2322db39c96c27"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:87380fb1f3089d2a0b8b00f006ed12bd41bd858fabfa7330c954c70f50ed8757"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e474fc718e73ba5ec5180358aa07f6aded0ff5f2abe700e3115c37d75c947e18"},
+ {file = "propcache-0.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:17d1c688a443355234f3c031349da69444be052613483f3e4158eef751abcd8a"},
+ {file = "propcache-0.3.1-cp313-cp313t-win32.whl", hash = "sha256:359e81a949a7619802eb601d66d37072b79b79c2505e6d3fd8b945538411400d"},
+ {file = "propcache-0.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e7fb9a84c9abbf2b2683fa3e7b0d7da4d8ecf139a1c635732a8bda29c5214b0e"},
+ {file = "propcache-0.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ed5f6d2edbf349bd8d630e81f474d33d6ae5d07760c44d33cd808e2f5c8f4ae6"},
+ {file = "propcache-0.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:668ddddc9f3075af019f784456267eb504cb77c2c4bd46cc8402d723b4d200bf"},
+ {file = "propcache-0.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0c86e7ceea56376216eba345aa1fc6a8a6b27ac236181f840d1d7e6a1ea9ba5c"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83be47aa4e35b87c106fc0c84c0fc069d3f9b9b06d3c494cd404ec6747544894"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:27c6ac6aa9fc7bc662f594ef380707494cb42c22786a558d95fcdedb9aa5d035"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64a956dff37080b352c1c40b2966b09defb014347043e740d420ca1eb7c9b908"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82de5da8c8893056603ac2d6a89eb8b4df49abf1a7c19d536984c8dd63f481d5"},
+ {file = "propcache-0.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c3c3a203c375b08fd06a20da3cf7aac293b834b6f4f4db71190e8422750cca5"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b303b194c2e6f171cfddf8b8ba30baefccf03d36a4d9cab7fd0bb68ba476a3d7"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:916cd229b0150129d645ec51614d38129ee74c03293a9f3f17537be0029a9641"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a461959ead5b38e2581998700b26346b78cd98540b5524796c175722f18b0294"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:069e7212890b0bcf9b2be0a03afb0c2d5161d91e1bf51569a64f629acc7defbf"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ef2e4e91fb3945769e14ce82ed53007195e616a63aa43b40fb7ebaaf907c8d4c"},
+ {file = "propcache-0.3.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8638f99dca15b9dff328fb6273e09f03d1c50d9b6512f3b65a4154588a7595fe"},
+ {file = "propcache-0.3.1-cp39-cp39-win32.whl", hash = "sha256:6f173bbfe976105aaa890b712d1759de339d8a7cef2fc0a1714cc1a1e1c47f64"},
+ {file = "propcache-0.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:603f1fe4144420374f1a69b907494c3acbc867a581c2d49d4175b0de7cc64566"},
+ {file = "propcache-0.3.1-py3-none-any.whl", hash = "sha256:9a8ecf38de50a7f518c21568c80f985e776397b902f1ce0b01f799aba1608b40"},
+ {file = "propcache-0.3.1.tar.gz", hash = "sha256:40d980c33765359098837527e18eddefc9a24cea5b45e078a7f3bb5b032c6ecf"},
]
[[package]]
name = "pydantic"
-version = "2.5.3"
+version = "2.11.3"
description = "Data validation using Python type hints"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "pydantic-2.5.3-py3-none-any.whl", hash = "sha256:d0caf5954bee831b6bfe7e338c32b9e30c85dfe080c843680783ac2b631673b4"},
- {file = "pydantic-2.5.3.tar.gz", hash = "sha256:b3ef57c62535b0941697cce638c08900d87fcb67e29cfa99e8a68f747f393f7a"},
+ {file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"},
+ {file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"},
]
[package.dependencies]
-annotated-types = ">=0.4.0"
-importlib-metadata = {version = "*", markers = "python_version == \"3.7\""}
-pydantic-core = "2.14.6"
-typing-extensions = ">=4.6.1"
+annotated-types = ">=0.6.0"
+pydantic-core = "2.33.1"
+typing-extensions = ">=4.12.2"
+typing-inspection = ">=0.4.0"
[package.extras]
email = ["email-validator (>=2.0.0)"]
+timezone = ["tzdata"]
[[package]]
name = "pydantic-core"
-version = "2.14.6"
-description = ""
+version = "2.33.1"
+description = "Core functionality for Pydantic validation and serialization"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "pydantic_core-2.14.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:72f9a942d739f09cd42fffe5dc759928217649f070056f03c70df14f5770acf9"},
- {file = "pydantic_core-2.14.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6a31d98c0d69776c2576dda4b77b8e0c69ad08e8b539c25c7d0ca0dc19a50d6c"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aa90562bc079c6c290f0512b21768967f9968e4cfea84ea4ff5af5d917016e4"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:370ffecb5316ed23b667d99ce4debe53ea664b99cc37bfa2af47bc769056d534"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f85f3843bdb1fe80e8c206fe6eed7a1caeae897e496542cee499c374a85c6e08"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862bf828112e19685b76ca499b379338fd4c5c269d897e218b2ae8fcb80139d"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036137b5ad0cb0004c75b579445a1efccd072387a36c7f217bb8efd1afbe5245"},
- {file = "pydantic_core-2.14.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92879bce89f91f4b2416eba4429c7b5ca22c45ef4a499c39f0c5c69257522c7c"},
- {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c08de15d50fa190d577e8591f0329a643eeaed696d7771760295998aca6bc66"},
- {file = "pydantic_core-2.14.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:36099c69f6b14fc2c49d7996cbf4f87ec4f0e66d1c74aa05228583225a07b590"},
- {file = "pydantic_core-2.14.6-cp310-none-win32.whl", hash = "sha256:7be719e4d2ae6c314f72844ba9d69e38dff342bc360379f7c8537c48e23034b7"},
- {file = "pydantic_core-2.14.6-cp310-none-win_amd64.whl", hash = "sha256:36fa402dcdc8ea7f1b0ddcf0df4254cc6b2e08f8cd80e7010d4c4ae6e86b2a87"},
- {file = "pydantic_core-2.14.6-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:dea7fcd62915fb150cdc373212141a30037e11b761fbced340e9db3379b892d4"},
- {file = "pydantic_core-2.14.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffff855100bc066ff2cd3aa4a60bc9534661816b110f0243e59503ec2df38421"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b027c86c66b8627eb90e57aee1f526df77dc6d8b354ec498be9a757d513b92b"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00b1087dabcee0b0ffd104f9f53d7d3eaddfaa314cdd6726143af6bc713aa27e"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75ec284328b60a4e91010c1acade0c30584f28a1f345bc8f72fe8b9e46ec6a96"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e1f4744eea1501404b20b0ac059ff7e3f96a97d3e3f48ce27a139e053bb370b"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2602177668f89b38b9f84b7b3435d0a72511ddef45dc14446811759b82235a1"},
- {file = "pydantic_core-2.14.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c8edaea3089bf908dd27da8f5d9e395c5b4dc092dbcce9b65e7156099b4b937"},
- {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:478e9e7b360dfec451daafe286998d4a1eeaecf6d69c427b834ae771cad4b622"},
- {file = "pydantic_core-2.14.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b6ca36c12a5120bad343eef193cc0122928c5c7466121da7c20f41160ba00ba2"},
- {file = "pydantic_core-2.14.6-cp311-none-win32.whl", hash = "sha256:2b8719037e570639e6b665a4050add43134d80b687288ba3ade18b22bbb29dd2"},
- {file = "pydantic_core-2.14.6-cp311-none-win_amd64.whl", hash = "sha256:78ee52ecc088c61cce32b2d30a826f929e1708f7b9247dc3b921aec367dc1b23"},
- {file = "pydantic_core-2.14.6-cp311-none-win_arm64.whl", hash = "sha256:a19b794f8fe6569472ff77602437ec4430f9b2b9ec7a1105cfd2232f9ba355e6"},
- {file = "pydantic_core-2.14.6-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:667aa2eac9cd0700af1ddb38b7b1ef246d8cf94c85637cbb03d7757ca4c3fdec"},
- {file = "pydantic_core-2.14.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cdee837710ef6b56ebd20245b83799fce40b265b3b406e51e8ccc5b85b9099b7"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c5bcf3414367e29f83fd66f7de64509a8fd2368b1edf4351e862910727d3e51"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26a92ae76f75d1915806b77cf459811e772d8f71fd1e4339c99750f0e7f6324f"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a983cca5ed1dd9a35e9e42ebf9f278d344603bfcb174ff99a5815f953925140a"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cb92f9061657287eded380d7dc455bbf115430b3aa4741bdc662d02977e7d0af"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4ace1e220b078c8e48e82c081e35002038657e4b37d403ce940fa679e57113b"},
- {file = "pydantic_core-2.14.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef633add81832f4b56d3b4c9408b43d530dfca29e68fb1b797dcb861a2c734cd"},
- {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7e90d6cc4aad2cc1f5e16ed56e46cebf4877c62403a311af20459c15da76fd91"},
- {file = "pydantic_core-2.14.6-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e8a5ac97ea521d7bde7621d86c30e86b798cdecd985723c4ed737a2aa9e77d0c"},
- {file = "pydantic_core-2.14.6-cp312-none-win32.whl", hash = "sha256:f27207e8ca3e5e021e2402ba942e5b4c629718e665c81b8b306f3c8b1ddbb786"},
- {file = "pydantic_core-2.14.6-cp312-none-win_amd64.whl", hash = "sha256:b3e5fe4538001bb82e2295b8d2a39356a84694c97cb73a566dc36328b9f83b40"},
- {file = "pydantic_core-2.14.6-cp312-none-win_arm64.whl", hash = "sha256:64634ccf9d671c6be242a664a33c4acf12882670b09b3f163cd00a24cffbd74e"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:24368e31be2c88bd69340fbfe741b405302993242ccb476c5c3ff48aeee1afe0"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:e33b0834f1cf779aa839975f9d8755a7c2420510c0fa1e9fa0497de77cd35d2c"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6af4b3f52cc65f8a0bc8b1cd9676f8c21ef3e9132f21fed250f6958bd7223bed"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d15687d7d7f40333bd8266f3814c591c2e2cd263fa2116e314f60d82086e353a"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:095b707bb287bfd534044166ab767bec70a9bba3175dcdc3371782175c14e43c"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94fc0e6621e07d1e91c44e016cc0b189b48db053061cc22d6298a611de8071bb"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ce830e480f6774608dedfd4a90c42aac4a7af0a711f1b52f807130c2e434c06"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a306cdd2ad3a7d795d8e617a58c3a2ed0f76c8496fb7621b6cd514eb1532cae8"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f5fa187bde8524b1e37ba894db13aadd64faa884657473b03a019f625cee9a8"},
- {file = "pydantic_core-2.14.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:438027a975cc213a47c5d70672e0d29776082155cfae540c4e225716586be75e"},
- {file = "pydantic_core-2.14.6-cp37-none-win32.whl", hash = "sha256:f96ae96a060a8072ceff4cfde89d261837b4294a4f28b84a28765470d502ccc6"},
- {file = "pydantic_core-2.14.6-cp37-none-win_amd64.whl", hash = "sha256:e646c0e282e960345314f42f2cea5e0b5f56938c093541ea6dbf11aec2862391"},
- {file = "pydantic_core-2.14.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:db453f2da3f59a348f514cfbfeb042393b68720787bbef2b4c6068ea362c8149"},
- {file = "pydantic_core-2.14.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3860c62057acd95cc84044e758e47b18dcd8871a328ebc8ccdefd18b0d26a21b"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36026d8f99c58d7044413e1b819a67ca0e0b8ebe0f25e775e6c3d1fabb3c38fb"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8ed1af8692bd8d2a29d702f1a2e6065416d76897d726e45a1775b1444f5928a7"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:314ccc4264ce7d854941231cf71b592e30d8d368a71e50197c905874feacc8a8"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:982487f8931067a32e72d40ab6b47b1628a9c5d344be7f1a4e668fb462d2da42"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dbe357bc4ddda078f79d2a36fc1dd0494a7f2fad83a0a684465b6f24b46fe80"},
- {file = "pydantic_core-2.14.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2f6ffc6701a0eb28648c845f4945a194dc7ab3c651f535b81793251e1185ac3d"},
- {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7f5025db12fc6de7bc1104d826d5aee1d172f9ba6ca936bf6474c2148ac336c1"},
- {file = "pydantic_core-2.14.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dab03ed811ed1c71d700ed08bde8431cf429bbe59e423394f0f4055f1ca0ea60"},
- {file = "pydantic_core-2.14.6-cp38-none-win32.whl", hash = "sha256:dfcbebdb3c4b6f739a91769aea5ed615023f3c88cb70df812849aef634c25fbe"},
- {file = "pydantic_core-2.14.6-cp38-none-win_amd64.whl", hash = "sha256:99b14dbea2fdb563d8b5a57c9badfcd72083f6006caf8e126b491519c7d64ca8"},
- {file = "pydantic_core-2.14.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:4ce8299b481bcb68e5c82002b96e411796b844d72b3e92a3fbedfe8e19813eab"},
- {file = "pydantic_core-2.14.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b9a9d92f10772d2a181b5ca339dee066ab7d1c9a34ae2421b2a52556e719756f"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd9e98b408384989ea4ab60206b8e100d8687da18b5c813c11e92fd8212a98e0"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4f86f1f318e56f5cbb282fe61eb84767aee743ebe32c7c0834690ebea50c0a6b"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86ce5fcfc3accf3a07a729779d0b86c5d0309a4764c897d86c11089be61da160"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dcf1978be02153c6a31692d4fbcc2a3f1db9da36039ead23173bc256ee3b91b"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eedf97be7bc3dbc8addcef4142f4b4164066df0c6f36397ae4aaed3eb187d8ab"},
- {file = "pydantic_core-2.14.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5f916acf8afbcab6bacbb376ba7dc61f845367901ecd5e328fc4d4aef2fcab0"},
- {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8a14c192c1d724c3acbfb3f10a958c55a2638391319ce8078cb36c02283959b9"},
- {file = "pydantic_core-2.14.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0348b1dc6b76041516e8a854ff95b21c55f5a411c3297d2ca52f5528e49d8411"},
- {file = "pydantic_core-2.14.6-cp39-none-win32.whl", hash = "sha256:de2a0645a923ba57c5527497daf8ec5df69c6eadf869e9cd46e86349146e5975"},
- {file = "pydantic_core-2.14.6-cp39-none-win_amd64.whl", hash = "sha256:aca48506a9c20f68ee61c87f2008f81f8ee99f8d7f0104bff3c47e2d148f89d9"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d5c28525c19f5bb1e09511669bb57353d22b94cf8b65f3a8d141c389a55dec95"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:78d0768ee59baa3de0f4adac9e3748b4b1fffc52143caebddfd5ea2961595277"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b93785eadaef932e4fe9c6e12ba67beb1b3f1e5495631419c784ab87e975670"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a874f21f87c485310944b2b2734cd6d318765bcbb7515eead33af9641816506e"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b89f4477d915ea43b4ceea6756f63f0288941b6443a2b28c69004fe07fde0d0d"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:172de779e2a153d36ee690dbc49c6db568d7b33b18dc56b69a7514aecbcf380d"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dfcebb950aa7e667ec226a442722134539e77c575f6cfaa423f24371bb8d2e94"},
- {file = "pydantic_core-2.14.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:55a23dcd98c858c0db44fc5c04fc7ed81c4b4d33c653a7c45ddaebf6563a2f66"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4241204e4b36ab5ae466ecec5c4c16527a054c69f99bba20f6f75232a6a534e2"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e574de99d735b3fc8364cba9912c2bec2da78775eba95cbb225ef7dda6acea24"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1302a54f87b5cd8528e4d6d1bf2133b6aa7c6122ff8e9dc5220fbc1e07bffebd"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8e81e4b55930e5ffab4a68db1af431629cf2e4066dbdbfef65348b8ab804ea8"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c99462ffc538717b3e60151dfaf91125f637e801f5ab008f81c402f1dff0cd0f"},
- {file = "pydantic_core-2.14.6-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e4cf2d5829f6963a5483ec01578ee76d329eb5caf330ecd05b3edd697e7d768a"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:cf10b7d58ae4a1f07fccbf4a0a956d705356fea05fb4c70608bb6fa81d103cda"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:399ac0891c284fa8eb998bcfa323f2234858f5d2efca3950ae58c8f88830f145"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c6a5c79b28003543db3ba67d1df336f253a87d3112dac3a51b94f7d48e4c0e1"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599c87d79cab2a6a2a9df4aefe0455e61e7d2aeede2f8577c1b7c0aec643ee8e"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43e166ad47ba900f2542a80d83f9fc65fe99eb63ceec4debec160ae729824052"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a0b5db001b98e1c649dd55afa928e75aa4087e587b9524a4992316fa23c9fba"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:747265448cb57a9f37572a488a57d873fd96bf51e5bb7edb52cfb37124516da4"},
- {file = "pydantic_core-2.14.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7ebe3416785f65c28f4f9441e916bfc8a54179c8dea73c23023f7086fa601c5d"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:86c963186ca5e50d5c8287b1d1c9d3f8f024cbe343d048c5bd282aec2d8641f2"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e0641b506486f0b4cd1500a2a65740243e8670a2549bb02bc4556a83af84ae03"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71d72ca5eaaa8d38c8df16b7deb1a2da4f650c41b58bb142f3fb75d5ad4a611f"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e524624eace5c59af499cd97dc18bb201dc6a7a2da24bfc66ef151c69a5f2a"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3dde6cac75e0b0902778978d3b1646ca9f438654395a362cb21d9ad34b24acf"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:00646784f6cd993b1e1c0e7b0fdcbccc375d539db95555477771c27555e3c556"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:23598acb8ccaa3d1d875ef3b35cb6376535095e9405d91a3d57a8c7db5d29341"},
- {file = "pydantic_core-2.14.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7f41533d7e3cf9520065f610b41ac1c76bc2161415955fbcead4981b22c7611e"},
- {file = "pydantic_core-2.14.6.tar.gz", hash = "sha256:1fd0c1d395372843fba13a51c28e3bb9d59bd7aebfeb17358ffaaa1e4dbbe948"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3077cfdb6125cc8dab61b155fdd714663e401f0e6883f9632118ec12cf42df26"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ffab8b2908d152e74862d276cf5017c81a2f3719f14e8e3e8d6b83fda863927"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5183e4f6a2d468787243ebcd70cf4098c247e60d73fb7d68d5bc1e1beaa0c4db"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:398a38d323f37714023be1e0285765f0a27243a8b1506b7b7de87b647b517e48"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87d3776f0001b43acebfa86f8c64019c043b55cc5a6a2e313d728b5c95b46969"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c566dd9c5f63d22226409553531f89de0cac55397f2ab8d97d6f06cfce6d947e"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0d5f3acc81452c56895e90643a625302bd6be351e7010664151cc55b7b97f89"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d3a07fadec2a13274a8d861d3d37c61e97a816beae717efccaa4b36dfcaadcde"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f99aeda58dce827f76963ee87a0ebe75e648c72ff9ba1174a253f6744f518f65"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:902dbc832141aa0ec374f4310f1e4e7febeebc3256f00dc359a9ac3f264a45dc"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fe44d56aa0b00d66640aa84a3cbe80b7a3ccdc6f0b1ca71090696a6d4777c091"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win32.whl", hash = "sha256:ed3eb16d51257c763539bde21e011092f127a2202692afaeaccb50db55a31383"},
+ {file = "pydantic_core-2.33.1-cp310-cp310-win_amd64.whl", hash = "sha256:694ad99a7f6718c1a498dc170ca430687a39894a60327f548e02a9c7ee4b6504"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e966fc3caaf9f1d96b349b0341c70c8d6573bf1bac7261f7b0ba88f96c56c24"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bfd0adeee563d59c598ceabddf2c92eec77abcb3f4a391b19aa7366170bd9e30"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91815221101ad3c6b507804178a7bb5cb7b2ead9ecd600041669c8d805ebd595"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9fea9c1869bb4742d174a57b4700c6dadea951df8b06de40c2fedb4f02931c2e"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d20eb4861329bb2484c021b9d9a977566ab16d84000a57e28061151c62b349a"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb935c5591573ae3201640579f30128ccc10739b45663f93c06796854405505"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c964fd24e6166420d18fb53996d8c9fd6eac9bf5ae3ec3d03015be4414ce497f"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:681d65e9011f7392db5aa002b7423cc442d6a673c635668c227c6c8d0e5a4f77"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e100c52f7355a48413e2999bfb4e139d2977a904495441b374f3d4fb4a170961"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:048831bd363490be79acdd3232f74a0e9951b11b2b4cc058aeb72b22fdc3abe1"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bdc84017d28459c00db6f918a7272a5190bec3090058334e43a76afb279eac7c"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win32.whl", hash = "sha256:32cd11c5914d1179df70406427097c7dcde19fddf1418c787540f4b730289896"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_amd64.whl", hash = "sha256:2ea62419ba8c397e7da28a9170a16219d310d2cf4970dbc65c32faf20d828c83"},
+ {file = "pydantic_core-2.33.1-cp311-cp311-win_arm64.whl", hash = "sha256:fc903512177361e868bc1f5b80ac8c8a6e05fcdd574a5fb5ffeac5a9982b9e89"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1293d7febb995e9d3ec3ea09caf1a26214eec45b0f29f6074abb004723fc1de8"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:99b56acd433386c8f20be5c4000786d1e7ca0523c8eefc995d14d79c7a081498"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35a5ec3fa8c2fe6c53e1b2ccc2454398f95d5393ab398478f53e1afbbeb4d939"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b172f7b9d2f3abc0efd12e3386f7e48b576ef309544ac3a63e5e9cdd2e24585d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9097b9f17f91eea659b9ec58148c0747ec354a42f7389b9d50701610d86f812e"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc77ec5b7e2118b152b0d886c7514a4653bcb58c6b1d760134a9fab915f777b3"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3d15245b08fa4a84cefc6c9222e6f37c98111c8679fbd94aa145f9a0ae23d"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ef99779001d7ac2e2461d8ab55d3373fe7315caefdbecd8ced75304ae5a6fc6b"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fc6bf8869e193855e8d91d91f6bf59699a5cdfaa47a404e278e776dd7f168b39"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:b1caa0bc2741b043db7823843e1bde8aaa58a55a58fda06083b0569f8b45693a"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ec259f62538e8bf364903a7d0d0239447059f9434b284f5536e8402b7dd198db"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win32.whl", hash = "sha256:e14f369c98a7c15772b9da98987f58e2b509a93235582838bd0d1d8c08b68fda"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_amd64.whl", hash = "sha256:1c607801d85e2e123357b3893f82c97a42856192997b95b4d8325deb1cd0c5f4"},
+ {file = "pydantic_core-2.33.1-cp312-cp312-win_arm64.whl", hash = "sha256:8d13f0276806ee722e70a1c93da19748594f19ac4299c7e41237fc791d1861ea"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:70af6a21237b53d1fe7b9325b20e65cbf2f0a848cf77bed492b029139701e66a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:282b3fe1bbbe5ae35224a0dbd05aed9ccabccd241e8e6b60370484234b456266"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b315e596282bbb5822d0c7ee9d255595bd7506d1cb20c2911a4da0b970187d3"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dfae24cf9921875ca0ca6a8ecb4bb2f13c855794ed0d468d6abbec6e6dcd44a"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd8ecfde08d8bfadaea669e83c63939af76f4cf5538a72597016edfa3fad516"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f593494876eae852dc98c43c6f260f45abdbfeec9e4324e31a481d948214764"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948b73114f47fd7016088e5186d13faf5e1b2fe83f5e320e371f035557fd264d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e11f3864eb516af21b01e25fac915a82e9ddad3bb0fb9e95a246067398b435a4"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:549150be302428b56fdad0c23c2741dcdb5572413776826c965619a25d9c6bde"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:495bc156026efafd9ef2d82372bd38afce78ddd82bf28ef5276c469e57c0c83e"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ec79de2a8680b1a67a07490bddf9636d5c2fab609ba8c57597e855fa5fa4dacd"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win32.whl", hash = "sha256:ee12a7be1742f81b8a65b36c6921022301d466b82d80315d215c4c691724986f"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_amd64.whl", hash = "sha256:ede9b407e39949d2afc46385ce6bd6e11588660c26f80576c11c958e6647bc40"},
+ {file = "pydantic_core-2.33.1-cp313-cp313-win_arm64.whl", hash = "sha256:aa687a23d4b7871a00e03ca96a09cad0f28f443690d300500603bd0adba4b523"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:401d7b76e1000d0dd5538e6381d28febdcacb097c8d340dde7d7fc6e13e9f95d"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7aeb055a42d734c0255c9e489ac67e75397d59c6fbe60d155851e9782f276a9c"},
+ {file = "pydantic_core-2.33.1-cp313-cp313t-win_amd64.whl", hash = "sha256:338ea9b73e6e109f15ab439e62cb3b78aa752c7fd9536794112e14bee02c8d18"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5ab77f45d33d264de66e1884fca158bc920cb5e27fd0764a72f72f5756ae8bdb"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7aaba1b4b03aaea7bb59e1b5856d734be011d3e6d98f5bcaa98cb30f375f2ad"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fb66263e9ba8fea2aa85e1e5578980d127fb37d7f2e292773e7bc3a38fb0c7b"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f2648b9262607a7fb41d782cc263b48032ff7a03a835581abbf7a3bec62bcf5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:723c5630c4259400818b4ad096735a829074601805d07f8cafc366d95786d331"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d100e3ae783d2167782391e0c1c7a20a31f55f8015f3293647544df3f9c67824"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177d50460bc976a0369920b6c744d927b0ecb8606fb56858ff542560251b19e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3edde68d1a1f9af1273b2fe798997b33f90308fb6d44d8550c89fc6a3647cf6"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a62c3c3ef6a7e2c45f7853b10b5bc4ddefd6ee3cd31024754a1a5842da7d598d"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:c91dbb0ab683fa0cd64a6e81907c8ff41d6497c346890e26b23de7ee55353f96"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f466e8bf0a62dc43e068c12166281c2eca72121dd2adc1040f3aa1e21ef8599"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win32.whl", hash = "sha256:ab0277cedb698749caada82e5d099dc9fed3f906a30d4c382d1a21725777a1e5"},
+ {file = "pydantic_core-2.33.1-cp39-cp39-win_amd64.whl", hash = "sha256:5773da0ee2d17136b1f1c6fbde543398d452a6ad2a7b54ea1033e2daa739b8d2"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c834f54f8f4640fd7e4b193f80eb25a0602bba9e19b3cd2fc7ffe8199f5ae02"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:049e0de24cf23766f12cc5cc71d8abc07d4a9deb9061b334b62093dedc7cb068"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a28239037b3d6f16916a4c831a5a0eadf856bdd6d2e92c10a0da3a59eadcf3e"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d3da303ab5f378a268fa7d45f37d7d85c3ec19769f28d2cc0c61826a8de21fe"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:25626fb37b3c543818c14821afe0fd3830bc327a43953bc88db924b68c5723f1"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3ab2d36e20fbfcce8f02d73c33a8a7362980cff717926bbae030b93ae46b56c7"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:2f9284e11c751b003fd4215ad92d325d92c9cb19ee6729ebd87e3250072cdcde"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:048c01eee07d37cbd066fc512b9d8b5ea88ceeb4e629ab94b3e56965ad655add"},
+ {file = "pydantic_core-2.33.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5ccd429694cf26af7997595d627dd2637e7932214486f55b8a357edaac9dae8c"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3a371dc00282c4b84246509a5ddc808e61b9864aa1eae9ecc92bb1268b82db4a"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:f59295ecc75a1788af8ba92f2e8c6eeaa5a94c22fc4d151e8d9638814f85c8fc"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08530b8ac922003033f399128505f513e30ca770527cc8bbacf75a84fcc2c74b"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae370459da6a5466978c0eacf90690cb57ec9d533f8e63e564ef3822bfa04fe"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3de2777e3b9f4d603112f78006f4ae0acb936e95f06da6cb1a45fbad6bdb4b5"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3a64e81e8cba118e108d7126362ea30e021291b7805d47e4896e52c791be2761"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:52928d8c1b6bda03cc6d811e8923dffc87a2d3c8b3bfd2ce16471c7147a24850"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1b30d92c9412beb5ac6b10a3eb7ef92ccb14e3f2a8d7732e2d739f58b3aa7544"},
+ {file = "pydantic_core-2.33.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f995719707e0e29f0f41a8aa3bcea6e761a36c9136104d3189eafb83f5cec5e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7edbc454a29fc6aeae1e1eecba4f07b63b8d76e76a748532233c4c167b4cb9ea"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:ad05b683963f69a1d5d2c2bdab1274a31221ca737dbbceaa32bcb67359453cdd"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df6a94bf9452c6da9b5d76ed229a5683d0306ccb91cca8e1eea883189780d568"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7965c13b3967909a09ecc91f21d09cfc4576bf78140b988904e94f130f188396"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3f1fdb790440a34f6ecf7679e1863b825cb5ffde858a9197f851168ed08371e5"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:5277aec8d879f8d05168fdd17ae811dd313b8ff894aeeaf7cd34ad28b4d77e33"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8ab581d3530611897d863d1a649fb0644b860286b4718db919bfd51ece41f10b"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0483847fa9ad5e3412265c1bd72aad35235512d9ce9d27d81a56d935ef489672"},
+ {file = "pydantic_core-2.33.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:de9e06abe3cc5ec6a2d5f75bc99b0bdca4f5c719a5b34026f8c57efbdecd2ee3"},
+ {file = "pydantic_core-2.33.1.tar.gz", hash = "sha256:bcc9c6fdb0ced789245b02b7d6603e17d1563064ddcfc36f046b61c0c05dd9df"},
]
[package.dependencies]
@@ -844,134 +1009,139 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pytest"
-version = "6.2.5"
+version = "8.3.5"
description = "pytest: simple powerful testing with Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
+groups = ["dev"]
files = [
- {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
- {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
+ {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"},
+ {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"},
]
[package.dependencies]
-atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
-attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
-importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
-pluggy = ">=0.12,<2.0"
-py = ">=1.8.2"
-toml = "*"
+pluggy = ">=1.5,<2"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
-testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
+dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
-version = "0.17.2"
+version = "0.26.0"
description = "Pytest support for asyncio"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["dev"]
files = [
- {file = "pytest-asyncio-0.17.2.tar.gz", hash = "sha256:6d895b02432c028e6957d25fc936494e78c6305736e785d9fee408b1efbc7ff4"},
- {file = "pytest_asyncio-0.17.2-py3-none-any.whl", hash = "sha256:e0fe5dbea40516b661ef1bcfe0bd9461c2847c4ef4bb40012324f2454fb7d56d"},
+ {file = "pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0"},
+ {file = "pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f"},
]
[package.dependencies]
-pytest = ">=6.1.0"
-typing-extensions = {version = ">=4.0", markers = "python_version < \"3.8\""}
+pytest = ">=8.2,<9"
+typing-extensions = {version = ">=4.12", markers = "python_version < \"3.10\""}
[package.extras]
-testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
-version = "3.0.0"
+version = "6.1.1"
description = "Pytest plugin for measuring coverage."
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.9"
+groups = ["dev"]
files = [
- {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
- {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
+ {file = "pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde"},
+ {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"},
]
[package.dependencies]
-coverage = {version = ">=5.2.1", extras = ["toml"]}
+coverage = {version = ">=7.5", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
-testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "pyyaml"
-version = "6.0.1"
+version = "6.0.2"
description = "YAML parser and emitter for Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"},
- {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
- {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
- {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
- {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
- {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
- {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
- {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
- {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
- {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
- {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
- {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
- {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
- {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
- {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
- {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
- {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"},
- {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"},
- {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"},
- {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"},
- {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"},
- {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"},
- {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"},
- {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
- {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
- {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
- {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
- {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
- {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
- {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
- {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
- {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
- {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
- {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"},
+ {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"},
+ {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"},
+ {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"},
+ {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"},
+ {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"},
+ {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"},
+ {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"},
+ {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"},
+ {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"},
+ {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"},
+ {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"},
+ {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"},
+ {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"},
+ {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"},
+ {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"},
+ {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"},
+ {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"},
+ {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"},
+ {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"},
+ {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"},
+ {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"},
+ {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"},
+ {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"},
+ {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"},
+ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
]
[[package]]
name = "requests"
-version = "2.31.0"
+version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main"]
files = [
- {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
- {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+ {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
@@ -984,180 +1154,220 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
-[[package]]
-name = "toml"
-version = "0.10.2"
-description = "Python Library for Tom's Obvious, Minimal Language"
-optional = false
-python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-files = [
- {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
- {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
-]
-
[[package]]
name = "tomli"
-version = "2.0.1"
+version = "2.2.1"
description = "A lil' TOML parser"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["dev"]
+markers = "python_full_version <= \"3.11.0a6\""
files = [
- {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
- {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
+ {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"},
+ {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"},
+ {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"},
+ {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"},
+ {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"},
+ {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"},
+ {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"},
+ {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"},
+ {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"},
+ {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"},
+ {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"},
+ {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"},
+ {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"},
+ {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"},
+ {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"},
+ {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"},
+ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
]
[[package]]
name = "tqdm"
-version = "4.66.1"
+version = "4.67.1"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
+groups = ["main"]
files = [
- {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
- {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
+ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
+ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
-dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "typing-extensions"
-version = "4.7.1"
-description = "Backported and Experimental Type Hints for Python 3.7+"
+version = "4.13.2"
+description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
+groups = ["main", "dev"]
files = [
- {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"},
- {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"},
+ {file = "typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c"},
+ {file = "typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef"},
]
+markers = {dev = "python_version < \"3.10\""}
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.0"
+description = "Runtime typing introspection tools"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f"},
+ {file = "typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12.0"
[[package]]
name = "urllib3"
-version = "2.0.5"
+version = "2.4.0"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"},
- {file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"},
+ {file = "urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813"},
+ {file = "urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
+h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "yarl"
-version = "1.9.2"
+version = "1.19.0"
description = "Yet another URL library"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.9"
+groups = ["main"]
files = [
- {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"},
- {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"},
- {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"},
- {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"},
- {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"},
- {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"},
- {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"},
- {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"},
- {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"},
- {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"},
- {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"},
- {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"},
- {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"},
- {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"},
- {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"},
- {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"},
- {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"},
- {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"},
- {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"},
- {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"},
- {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"},
- {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"},
- {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"},
- {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"},
- {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"},
- {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"},
- {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"},
- {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"},
+ {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0bae32f8ebd35c04d6528cedb4a26b8bf25339d3616b04613b97347f919b76d3"},
+ {file = "yarl-1.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8015a076daf77823e7ebdcba474156587391dab4e70c732822960368c01251e6"},
+ {file = "yarl-1.19.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9973ac95327f5d699eb620286c39365990b240031672b5c436a4cd00539596c5"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd4b5fbd7b9dde785cfeb486b8cca211a0b138d4f3a7da27db89a25b3c482e5c"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:75460740005de5a912b19f657848aef419387426a40f581b1dc9fac0eb9addb5"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57abd66ca913f2cfbb51eb3dbbbac3648f1f6983f614a4446e0802e241441d2a"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ade37911b7c99ce28a959147cb28bffbd14cea9e7dd91021e06a8d2359a5aa"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8346ec72ada749a6b5d82bff7be72578eab056ad7ec38c04f668a685abde6af0"},
+ {file = "yarl-1.19.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e4cb14a6ee5b6649ccf1c6d648b4da9220e8277d4d4380593c03cc08d8fe937"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:66fc1c2926a73a2fb46e4b92e3a6c03904d9bc3a0b65e01cb7d2b84146a8bd3b"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:5a70201dd1e0a4304849b6445a9891d7210604c27e67da59091d5412bc19e51c"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4807aab1bdeab6ae6f296be46337a260ae4b1f3a8c2fcd373e236b4b2b46efd"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ae584afe81a1de4c1bb06672481050f0d001cad13163e3c019477409f638f9b7"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:30eaf4459df6e91f21b2999d1ee18f891bcd51e3cbe1de301b4858c84385895b"},
+ {file = "yarl-1.19.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0e617d45d03c8dec0dfce6f51f3e1b8a31aa81aaf4a4d1442fdb232bcf0c6d8c"},
+ {file = "yarl-1.19.0-cp310-cp310-win32.whl", hash = "sha256:32ba32d0fa23893fd8ea8d05bdb05de6eb19d7f2106787024fd969f4ba5466cb"},
+ {file = "yarl-1.19.0-cp310-cp310-win_amd64.whl", hash = "sha256:545575ecfcd465891b51546c2bcafdde0acd2c62c2097d8d71902050b20e4922"},
+ {file = "yarl-1.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:163ff326680de5f6d4966954cf9e3fe1bf980f5fee2255e46e89b8cf0f3418b5"},
+ {file = "yarl-1.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a626c4d9cca298d1be8625cff4b17004a9066330ac82d132bbda64a4c17c18d3"},
+ {file = "yarl-1.19.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:961c3e401ea7f13d02b8bb7cb0c709152a632a6e14cdc8119e9c6ee5596cd45d"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a39d7b807ab58e633ed760f80195cbd145b58ba265436af35f9080f1810dfe64"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c4228978fb59c6b10f60124ba8e311c26151e176df364e996f3f8ff8b93971b5"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba536b17ecf3c74a94239ec1137a3ad3caea8c0e4deb8c8d2ffe847d870a8c5"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a251e00e445d2e9df7b827c9843c0b87f58a3254aaa3f162fb610747491fe00f"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9b92431d8b4d4ca5ccbfdbac95b05a3a6cd70cd73aa62f32f9627acfde7549c"},
+ {file = "yarl-1.19.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec2f56edaf476f70b5831bbd59700b53d9dd011b1f77cd4846b5ab5c5eafdb3f"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:acf9b92c4245ac8b59bc7ec66a38d3dcb8d1f97fac934672529562bb824ecadb"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:57711f1465c06fee8825b95c0b83e82991e6d9425f9a042c3c19070a70ac92bf"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:528e86f5b1de0ad8dd758ddef4e0ed24f5d946d4a1cef80ffb2d4fca4e10f122"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:3b77173663e075d9e5a57e09d711e9da2f3266be729ecca0b8ae78190990d260"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:d8717924cf0a825b62b1a96fc7d28aab7f55a81bf5338b8ef41d7a76ab9223e9"},
+ {file = "yarl-1.19.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0df9f0221a78d858793f40cbea3915c29f969c11366646a92ca47e080a14f881"},
+ {file = "yarl-1.19.0-cp311-cp311-win32.whl", hash = "sha256:8b3ade62678ee2c7c10dcd6be19045135e9badad53108f7d2ed14896ee396045"},
+ {file = "yarl-1.19.0-cp311-cp311-win_amd64.whl", hash = "sha256:0626ee31edb23ac36bdffe607231de2cca055ad3a5e2dc5da587ef8bc6a321bc"},
+ {file = "yarl-1.19.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:7b687c334da3ff8eab848c9620c47a253d005e78335e9ce0d6868ed7e8fd170b"},
+ {file = "yarl-1.19.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b0fe766febcf523a2930b819c87bb92407ae1368662c1bc267234e79b20ff894"},
+ {file = "yarl-1.19.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:742ceffd3c7beeb2b20d47cdb92c513eef83c9ef88c46829f88d5b06be6734ee"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2af682a1e97437382ee0791eacbf540318bd487a942e068e7e0a6c571fadbbd3"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:63702f1a098d0eaaea755e9c9d63172be1acb9e2d4aeb28b187092bcc9ca2d17"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3560dcba3c71ae7382975dc1e912ee76e50b4cd7c34b454ed620d55464f11876"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68972df6a0cc47c8abaf77525a76ee5c5f6ea9bbdb79b9565b3234ded3c5e675"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5684e7ff93ea74e47542232bd132f608df4d449f8968fde6b05aaf9e08a140f9"},
+ {file = "yarl-1.19.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8182ad422bfacdebd4759ce3adc6055c0c79d4740aea1104e05652a81cd868c6"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aee5b90a5a9b71ac57400a7bdd0feaa27c51e8f961decc8d412e720a004a1791"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:8c0b2371858d5a814b08542d5d548adb03ff2d7ab32f23160e54e92250961a72"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:cd430c2b7df4ae92498da09e9b12cad5bdbb140d22d138f9e507de1aa3edfea3"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a93208282c0ccdf73065fd76c6c129bd428dba5ff65d338ae7d2ab27169861a0"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:b8179280cdeb4c36eb18d6534a328f9d40da60d2b96ac4a295c5f93e2799e9d9"},
+ {file = "yarl-1.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eda3c2b42dc0c389b7cfda2c4df81c12eeb552019e0de28bde8f913fc3d1fcf3"},
+ {file = "yarl-1.19.0-cp312-cp312-win32.whl", hash = "sha256:57f3fed859af367b9ca316ecc05ce79ce327d6466342734305aa5cc380e4d8be"},
+ {file = "yarl-1.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:5507c1f7dd3d41251b67eecba331c8b2157cfd324849879bebf74676ce76aff7"},
+ {file = "yarl-1.19.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:59281b9ed27bc410e0793833bcbe7fc149739d56ffa071d1e0fe70536a4f7b61"},
+ {file = "yarl-1.19.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d27a6482ad5e05e8bafd47bf42866f8a1c0c3345abcb48d4511b3c29ecc197dc"},
+ {file = "yarl-1.19.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7a8e19fd5a6fdf19a91f2409665c7a089ffe7b9b5394ab33c0eec04cbecdd01f"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cda34ab19099c3a1685ad48fe45172536610c312b993310b5f1ca3eb83453b36"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7908a25d33f94852b479910f9cae6cdb9e2a509894e8d5f416c8342c0253c397"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e66c14d162bac94973e767b24de5d7e6c5153f7305a64ff4fcba701210bcd638"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c03607bf932aa4cfae371e2dc9ca8b76faf031f106dac6a6ff1458418140c165"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9931343d1c1f4e77421687b6b94bbebd8a15a64ab8279adf6fbb047eff47e536"},
+ {file = "yarl-1.19.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:262087a8a0d73e1d169d45c2baf968126f93c97cf403e1af23a7d5455d52721f"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:70f384921c24e703d249a6ccdabeb57dd6312b568b504c69e428a8dd3e8e68ca"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:756b9ea5292a2c180d1fe782a377bc4159b3cfefaca7e41b5b0a00328ef62fa9"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cbeb9c145d534c240a63b6ecc8a8dd451faeb67b3dc61d729ec197bb93e29497"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:087ae8f8319848c18e0d114d0f56131a9c017f29200ab1413b0137ad7c83e2ae"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362f5480ba527b6c26ff58cff1f229afe8b7fdd54ee5ffac2ab827c1a75fc71c"},
+ {file = "yarl-1.19.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f408d4b4315e814e5c3668094e33d885f13c7809cbe831cbdc5b1bb8c7a448f4"},
+ {file = "yarl-1.19.0-cp313-cp313-win32.whl", hash = "sha256:24e4c367ad69988a2283dd45ea88172561ca24b2326b9781e164eb46eea68345"},
+ {file = "yarl-1.19.0-cp313-cp313-win_amd64.whl", hash = "sha256:0110f91c57ab43d1538dfa92d61c45e33b84df9257bd08fcfcda90cce931cbc9"},
+ {file = "yarl-1.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:85ac908cd5a97bbd3048cca9f1bf37b932ea26c3885099444f34b0bf5d5e9fa6"},
+ {file = "yarl-1.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6ba0931b559f1345df48a78521c31cfe356585670e8be22af84a33a39f7b9221"},
+ {file = "yarl-1.19.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5bc503e1c1fee1b86bcb58db67c032957a52cae39fe8ddd95441f414ffbab83e"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d995122dcaf180fd4830a9aa425abddab7c0246107c21ecca2fa085611fa7ce9"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:217f69e60a14da4eed454a030ea8283f8fbd01a7d6d81e57efb865856822489b"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aad67c8f13a4b79990082f72ef09c078a77de2b39899aabf3960a48069704973"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dff065a1a8ed051d7e641369ba1ad030d5a707afac54cf4ede7069b959898835"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada882e26b16ee651ab6544ce956f2f4beaed38261238f67c2a96db748e17741"},
+ {file = "yarl-1.19.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67a56b1acc7093451ea2de0687aa3bd4e58d6b4ef6cbeeaad137b45203deaade"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e97d2f0a06b39e231e59ebab0e6eec45c7683b339e8262299ac952707bdf7688"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:a5288adb7c59d0f54e4ad58d86fb06d4b26e08a59ed06d00a1aac978c0e32884"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1efbf4d03e6eddf5da27752e0b67a8e70599053436e9344d0969532baa99df53"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:f228f42f29cc87db67020f7d71624102b2c837686e55317b16e1d3ef2747a993"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:c515f7dd60ca724e4c62b34aeaa603188964abed2eb66bb8e220f7f104d5a187"},
+ {file = "yarl-1.19.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4815ec6d3d68a96557fa71bd36661b45ac773fb50e5cfa31a7e843edb098f060"},
+ {file = "yarl-1.19.0-cp39-cp39-win32.whl", hash = "sha256:9fac2dd1c5ecb921359d9546bc23a6dcc18c6acd50c6d96f118188d68010f497"},
+ {file = "yarl-1.19.0-cp39-cp39-win_amd64.whl", hash = "sha256:5864f539ce86b935053bfa18205fa08ce38e9a40ea4d51b19ce923345f0ed5db"},
+ {file = "yarl-1.19.0-py3-none-any.whl", hash = "sha256:a727101eb27f66727576630d02985d8a065d09cd0b5fcbe38a5793f71b2a97ef"},
+ {file = "yarl-1.19.0.tar.gz", hash = "sha256:01e02bb80ae0dbed44273c304095295106e1d9470460e773268a27d11e594892"},
]
[package.dependencies]
idna = ">=2.0"
multidict = ">=4.0"
-typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
-
-[[package]]
-name = "zipp"
-version = "3.15.0"
-description = "Backport of pathlib-compatible object wrapper for zip files"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"},
- {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"},
-]
-
-[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+propcache = ">=0.2.1"
[metadata]
-lock-version = "2.0"
-python-versions = "^3.7"
-content-hash = "b7fab8703967f2616ea59a98a437cd30f97f0c8d2a06e399d688814a2a2c64f8"
+lock-version = "2.1"
+python-versions = "^3.9"
+content-hash = "f136e898d37b7c7db1ccceb1822ade280d3542ca19cdd9dcf583cb9aefef11c6"
diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml
index 47ef9d717..1448d7618 100644
--- a/clients/python/pyproject.toml
+++ b/clients/python/pyproject.toml
@@ -11,15 +11,15 @@ repository = "https://github.com/huggingface/text-generation-inference"
[tool.poetry.dependencies]
-python = "^3.7"
+python = "^3.9"
pydantic = "> 2, < 3"
-aiohttp = "^3.8"
+aiohttp = "^3.11"
huggingface-hub = ">= 0.12, < 1.0"
-[tool.poetry.dev-dependencies]
-pytest = "^6.2.5"
-pytest-asyncio = "^0.17.2"
-pytest-cov = "^3.0.0"
+[tool.poetry.group.dev.dependencies]
+pytest = "^8"
+pytest-asyncio = "^0.26"
+pytest-cov = "^6.0.0"
[tool.pytest.ini_options]
asyncio_mode = "auto"
diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py
index 17bb73b5d..f3db6e68a 100644
--- a/clients/python/tests/conftest.py
+++ b/clients/python/tests/conftest.py
@@ -21,7 +21,7 @@ def fake_model():
@pytest.fixture
def unsupported_model():
- return "gpt2"
+ return "google-bert/bert-base-uncased"
@pytest.fixture
diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py
index 8aed865b7..0c702c636 100644
--- a/clients/python/tests/test_client.py
+++ b/clients/python/tests/test_client.py
@@ -2,7 +2,7 @@ import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
-from text_generation.types import FinishReason, InputToken
+from text_generation.types import FinishReason
def test_generate(llama_7b_url, hf_headers):
@@ -13,8 +13,8 @@ def test_generate(llama_7b_url, hf_headers):
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
- assert len(response.details.prefill) == 2
- assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
+ assert len(response.details.prefill) == 0
+ # assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == "_"
@@ -83,11 +83,11 @@ async def test_generate_async(llama_7b_url, hf_headers):
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
- assert len(response.details.prefill) == 2
- assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
- assert response.details.prefill[1] == InputToken(
- id=1243, text="test", logprob=-10.96875
- )
+ assert len(response.details.prefill) == 0
+ # assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None)
+ # assert response.details.prefill[1] == InputToken(
+ # id=1243, text="test", logprob=-10.96875
+ # )
assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == "_"
diff --git a/clients/python/tests/test_inference_api.py b/clients/python/tests/test_inference_api.py
index 59297c26e..5a2584059 100644
--- a/clients/python/tests/test_inference_api.py
+++ b/clients/python/tests/test_inference_api.py
@@ -1,42 +1,42 @@
-import pytest
-
-from text_generation import (
- InferenceAPIClient,
- InferenceAPIAsyncClient,
- Client,
- AsyncClient,
-)
-from text_generation.errors import NotSupportedError, NotFoundError
-from text_generation.inference_api import check_model_support, deployed_models
-
-
-def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
- assert check_model_support(flan_t5_xxl)
- assert not check_model_support(unsupported_model)
-
- with pytest.raises(NotFoundError):
- check_model_support(fake_model)
-
-
-def test_deployed_models():
- deployed_models()
-
-
-def test_client(flan_t5_xxl):
- client = InferenceAPIClient(flan_t5_xxl)
- assert isinstance(client, Client)
-
-
-def test_client_unsupported_model(unsupported_model):
- with pytest.raises(NotSupportedError):
- InferenceAPIClient(unsupported_model)
-
-
-def test_async_client(flan_t5_xxl):
- client = InferenceAPIAsyncClient(flan_t5_xxl)
- assert isinstance(client, AsyncClient)
-
-
-def test_async_client_unsupported_model(unsupported_model):
- with pytest.raises(NotSupportedError):
- InferenceAPIAsyncClient(unsupported_model)
+# import pytest
+#
+# from text_generation import (
+# InferenceAPIClient,
+# InferenceAPIAsyncClient,
+# Client,
+# AsyncClient,
+# )
+# from text_generation.errors import NotSupportedError, NotFoundError
+# from text_generation.inference_api import check_model_support, deployed_models
+#
+#
+# def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
+# assert check_model_support(flan_t5_xxl)
+# assert not check_model_support(unsupported_model)
+#
+# with pytest.raises(NotFoundError):
+# check_model_support(fake_model)
+#
+#
+# def test_deployed_models():
+# deployed_models()
+#
+#
+# def test_client(flan_t5_xxl):
+# client = InferenceAPIClient(flan_t5_xxl)
+# assert isinstance(client, Client)
+#
+#
+# def test_client_unsupported_model(unsupported_model):
+# with pytest.raises(NotSupportedError):
+# InferenceAPIClient(unsupported_model)
+#
+#
+# def test_async_client(flan_t5_xxl):
+# client = InferenceAPIAsyncClient(flan_t5_xxl)
+# assert isinstance(client, AsyncClient)
+#
+#
+# def test_async_client_unsupported_model(unsupported_model):
+# with pytest.raises(NotSupportedError):
+# InferenceAPIAsyncClient(unsupported_model)
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 45301b632..0b60d93aa 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -867,7 +867,7 @@ class AsyncClient:
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
- async with session.post(self.base_url, json=request.dict()) as resp:
+ async with session.post(self.base_url, json=request.model_dump()) as resp:
payload = await resp.json()
if resp.status != 200:
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 1085075e4..6f51c153e 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -67,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
role: str
content: Optional[str] = None
- tool_calls: Optional[ChoiceDeltaToolCall] = None
+ tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
class Choice(BaseModel):
diff --git a/docs/openapi.json b/docs/openapi.json
index 1caf67525..ad5124798 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
- "version": "3.0.1-dev0"
+ "version": "3.2.3-dev0"
},
"paths": {
"/": {
@@ -1865,25 +1865,57 @@
}
},
"Message": {
- "type": "object",
- "required": [
- "role",
- "content"
- ],
- "properties": {
- "content": {
- "$ref": "#/components/schemas/MessageContent"
+ "allOf": [
+ {
+ "$ref": "#/components/schemas/MessageBody"
},
- "name": {
- "type": "string",
- "example": "\"David\"",
- "nullable": true
- },
- "role": {
- "type": "string",
- "example": "user"
+ {
+ "type": "object",
+ "required": [
+ "role"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "example": "\"David\"",
+ "nullable": true
+ },
+ "role": {
+ "type": "string",
+ "example": "user"
+ }
+ }
}
- }
+ ]
+ },
+ "MessageBody": {
+ "oneOf": [
+ {
+ "type": "object",
+ "required": [
+ "content"
+ ],
+ "properties": {
+ "content": {
+ "$ref": "#/components/schemas/MessageContent"
+ }
+ }
+ },
+ {
+ "type": "object",
+ "required": [
+ "tool_calls"
+ ],
+ "properties": {
+ "tool_calls": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/ToolCall"
+ }
+ }
+ }
+ }
+ ]
},
"MessageChunk": {
"oneOf": [
@@ -2116,9 +2148,6 @@
},
"StreamOptions": {
"type": "object",
- "required": [
- "include_usage"
- ],
"properties": {
"include_usage": {
"type": "boolean",
@@ -2179,6 +2208,10 @@
"role": {
"type": "string",
"example": "user"
+ },
+ "tool_call_id": {
+ "type": "string",
+ "nullable": true
}
}
},
@@ -2266,7 +2299,10 @@
"example": "assistant"
},
"tool_calls": {
- "$ref": "#/components/schemas/DeltaToolCall"
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/DeltaToolCall"
+ }
}
}
},
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index e31a37884..4c6f0151d 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -12,11 +12,15 @@
- local: installation_gaudi
title: Using TGI with Intel Gaudi
- local: installation_inferentia
- title: Using TGI with AWS Inferentia
+ title: Using TGI with AWS Trainium and Inferentia
+ - local: installation_tpu
+ title: Using TGI with Google TPUs
- local: installation_intel
title: Using TGI with Intel GPUs
- local: installation
title: Installation from source
+ - local: multi_backend_support
+ title: Multi-backend support
- local: architecture
title: Internal Architecture
@@ -45,6 +49,16 @@
- local: basic_tutorials/train_medusa
title: Train Medusa
title: Tutorials
+- sections:
+ - local: backends/neuron
+ title: Neuron
+ - local: backends/gaudi
+ title: Gaudi
+ - local: backends/trtllm
+ title: TensorRT-LLM
+ - local: backends/llamacpp
+ title: Llamacpp
+ title: Backends
- sections:
- local: reference/launcher
title: All TGI CLI options
diff --git a/docs/source/architecture.md b/docs/source/architecture.md
index 6660630d0..b475bb6dc 100644
--- a/docs/source/architecture.md
+++ b/docs/source/architecture.md
@@ -9,8 +9,10 @@ A high-level architecture diagram can be seen here:
This diagram shows well there are these separate components:
- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.
-- **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments.
+- **The model server**, responsible for receiving the gRPC requests and processing the inference on the model. If the model is sharded across multiple accelerators (e.g. multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.
+
+Note that for other backends (e.g. TRTLLM), the model server and launcher are specific to the backend.
The router and the model server can be two different machines, they do not need to be deployed together.
@@ -105,7 +107,7 @@ Several variants of the model server exist that are actively supported by Huggin
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
-- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
+- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained in the main TGI repository. Some model features differ.
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).
Not all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations.
diff --git a/docs/source/backends/gaudi.mdx b/docs/source/backends/gaudi.mdx
new file mode 100644
index 000000000..5c54d19d1
--- /dev/null
+++ b/docs/source/backends/gaudi.mdx
@@ -0,0 +1,334 @@
+# Gaudi Backend for Text Generation Inference
+
+## Overview
+Text Generation Inference (TGI) has been optimized to run on Gaudi hardware via the Gaudi backend for TGI.
+
+## Supported Hardware
+- **Gaudi1**: Available on [AWS EC2 DL1 instances](https://aws.amazon.com/ec2/instance-types/dl1/)
+- **Gaudi2**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
+- **Gaudi3**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)
+
+## Tutorial: Getting Started with TGI on Gaudi
+
+### Basic Usage
+The easiest way to run TGI on Gaudi is to use the official Docker image:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+hf_token=YOUR_HF_ACCESS_TOKEN
+
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+ --model-id $model
+```
+
+Once you see the `connected` log, the server is ready to accept requests:
+> 2024-05-22T19:31:48.302239Z INFO text_generation_router: router/src/main.rs:378: Connected
+
+You can create `YOUR_HF_ACCESS_TOKEN` at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). This is necessary to access gated models like Llama 3.1.
+
+### Making Your First Request
+You can send a request from a separate terminal:
+
+```bash
+curl 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
+ -H 'Content-Type: application/json'
+```
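+
+The same server also exposes TGI's OpenAI-compatible Messages API. The following is a minimal sketch; the `/v1/chat/completions` route and payload shape are taken from the general TGI documentation rather than anything Gaudi-specific:
+
+```bash
+curl 127.0.0.1:8080/v1/chat/completions \
+    -X POST \
+    -d '{"model":"tgi","messages":[{"role":"user","content":"What is Deep Learning?"}],"max_tokens":32,"stream":false}' \
+    -H 'Content-Type: application/json'
+```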
+
+## How-to Guides
+
+You can view the full list of supported models in the [Supported Models](https://huggingface.co/docs/text-generation-inference/backends/gaudi#supported-models) section.
+
+For example, to run Llama3.1-8B, you can use the following command:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+hf_token=YOUR_ACCESS_TOKEN
+
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+ --model-id $model
+```
+
+For the full list of service parameters, refer to the [launcher-arguments page](https://huggingface.co/docs/text-generation-inference/reference/launcher).
+
+The validated docker commands can be found in the [examples/docker_commands folder](https://github.com/huggingface/text-generation-inference/tree/main/backends/gaudi/examples/docker_commands).
+
+> Note: `--runtime=habana --cap-add=sys_nice --ipc=host` is required to enable Docker to use the Gaudi hardware (more details [here](https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html)).
+
+### How to Enable Multi-Card Inference (Sharding)
+
+TGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards. This is recommended to run large models and to speed up inference.
+
+For example, on a machine with 8 Gaudi cards, you can run:
+
+```bash
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ tgi-gaudi \
+ --model-id $model --sharded true --num-shard 8
+```
+
+
+We recommend always using sharding when running on a multi-card machine.
+
+
+### How to Use Different Precision Formats
+
+#### BF16 Precision (Default)
+By default, all models run with BF16 precision on Gaudi hardware.
+
+#### FP8 Precision
+TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html).
+
+To run FP8 Inference:
+
+1. Measure statistics using [Optimum Habana measurement script](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8)
+2. Run the model in TGI with QUANT_CONFIG setting - e.g. `-e QUANT_CONFIG=./quantization_config/maxabs_quant.json`.
+
+The following command example for FP8 inference assumes that the measurement in the first step above has already been done.
+
+Example for Llama3.1-70B on 8 cards with FP8 precision:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-70B-Instruct
+hf_token=YOUR_ACCESS_TOKEN
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -v $PWD/quantization_config:/usr/src/quantization_config \
+ -v $PWD/hqt_output:/usr/src/hqt_output \
+ -e QUANT_CONFIG=./quantization_config/maxabs_quant.json \
+ -e HF_TOKEN=$hf_token \
+ -e MAX_TOTAL_TOKENS=2048 \
+ -e BATCH_BUCKET_SIZE=256 \
+ -e PREFILL_BATCH_BUCKET_SIZE=4 \
+ -e PAD_SEQUENCE_TO_MULTIPLE_OF=64 \
+ ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+ --model-id $model \
+ --sharded true --num-shard 8 \
+ --max-input-tokens 1024 --max-total-tokens 2048 \
+ --max-batch-prefill-tokens 4096 --max-batch-size 256 \
+ --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
+```
+
+### How to Run Vision-Language Models (VLMs)
+
+Gaudi supports VLM inference.
+
+Example for Llava-v1.6-Mistral-7B on 1 card:
+
+Start the TGI server via the following command:
+```bash
+model=llava-hf/llava-v1.6-mistral-7b-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run -p 8080:80 \
+ --runtime=habana \
+ --cap-add=sys_nice \
+ --ipc=host \
+ -v $volume:/data \
+ -e PREFILL_BATCH_BUCKET_SIZE=1 \
+ -e BATCH_BUCKET_SIZE=1 \
+ ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+ --model-id $model \
+ --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
+ --max-total-tokens 8192 --max-batch-size 4
+```
+
+You can then send a request to the server via the following command:
+```bash
+curl -N 127.0.0.1:8080/generate \
+ -X POST \
+ -d '{"inputs":"What is this a picture of?\n\n","parameters":{"max_new_tokens":32}}' \
+ -H 'Content-Type: application/json'
+```
+
+> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for 2000 input tokens. For example, an image of size 512x512 is represented by 2800 tokens. Thus, `max-input-tokens` must be larger than the number of tokens associated with the image. Otherwise the image may be truncated. We set `BASE_IMAGE_TOKENS=2048` as the default image token value. This is the minimum value of `max-input-tokens`. You can override the environment variable `BASE_IMAGE_TOKENS` to change this value. The warmup will generate graphs with input length from `BASE_IMAGE_TOKENS` to `max-input-tokens`. For Llava-v1.6-Mistral-7B, the value of `max-batch-prefill-tokens` is 16384, which is calculated as follows: `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`.
+
+### How to Benchmark Performance
+
+We recommend using the [inference-benchmarker tool](https://github.com/huggingface/inference-benchmarker) to benchmark performance on Gaudi hardware.
+
+This benchmark tool simulates user requests and measures the performance of the model on realistic scenarios.
+
+To run it on the same machine, you can do the following:
+```bash
+MODEL=meta-llama/Llama-3.1-8B-Instruct
+HF_TOKEN=
+# run a benchmark to evaluate the performance of the model for chat use case
+# we mount results to the current directory
+docker run \
+ --rm \
+ -it \
+ --net host \
+ -v $(pwd):/opt/inference-benchmarker/results \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ ghcr.io/huggingface/inference-benchmarker:latest \
+ inference-benchmarker \
+ --tokenizer-name "$MODEL" \
+ --url http://localhost:8080 \
+ --profile chat
+```
+
+Please refer to the [inference-benchmarker README](https://github.com/huggingface/inference-benchmarker) for more details.
+
+### How to Profile Performance
+
+To collect performance profiling, you need to set the following environment variables:
+
+| Name | Value(s) | Default | Description |
+|--------------------| :--------- | :--------------- | :------------------------------------------------------- |
+| PROF_WAITSTEP | integer | 0 | Control profile wait steps |
+| PROF_WARMUPSTEP | integer | 0 | Control profile warmup steps |
+| PROF_STEP | integer | 0 | Enable/disable profile, control profile active steps |
+| PROF_PATH | string | /tmp/hpu_profile | Define profile folder |
+| PROF_RANKS | string | 0 | Comma-separated list of ranks to profile |
+| PROF_RECORD_SHAPES | True/False | False | Control record_shapes option in the profiler |
+
+To use these environment variables, add them to your docker run command with the -e flag. For example:
+
+```bash
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ -e PROF_WAITSTEP=10 \
+ -e PROF_WARMUPSTEP=10 \
+ -e PROF_STEP=1 \
+ -e PROF_PATH=/tmp/hpu_profile \
+ -e PROF_RANKS=0 \
+ -e PROF_RECORD_SHAPES=True \
+ ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+ --model-id $model
+```
+
+## Explanation: Understanding TGI on Gaudi
+
+### The Warmup Process
+
+To ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode, concatenate).
+
+Note: Model warmup can take several minutes, especially for FP8 inference. For faster subsequent runs, refer to [Disk Caching Eviction Policy](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#disk-caching-eviction-policy).
+
+### Understanding Parameter Tuning
+
+#### Sequence Length Parameters
+- `--max-input-tokens` is the maximum possible input prompt length. Default value is `4095`.
+- `--max-total-tokens` is the maximum possible total length of the sequence (input and output). Default value is `4096`.
+
+#### Batch Size Parameters
+- For prefill operation, please set `--max-batch-prefill-tokens` as `bs * max-input-tokens`, where `bs` is your expected maximum prefill batch size.
+- For decode operation, please set `--max-batch-size` as `bs`, where `bs` is your expected maximum decode batch size.
+- Please note that the batch size will always be padded to the nearest multiple of `BATCH_BUCKET_SIZE` and `PREFILL_BATCH_BUCKET_SIZE` (see the example below).
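+
+As a concrete sketch of how these values fit together, assume prompts of up to 1024 tokens, completions of up to 1024 tokens, an expected prefill batch size of 4 and an expected decode batch size of 32; the numbers are illustrative only, and the `$model`, `$volume` and `$hf_token` variables are reused from the examples above:
+
+```bash
+# max-batch-prefill-tokens = prefill bs (4) * max-input-tokens (1024) = 4096
+# max-total-tokens = max-input-tokens (1024) + max new tokens (1024) = 2048
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+    -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+    -e PREFILL_BATCH_BUCKET_SIZE=4 \
+    -e BATCH_BUCKET_SIZE=32 \
+    ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+    --model-id $model \
+    --max-input-tokens 1024 --max-total-tokens 2048 \
+    --max-batch-prefill-tokens 4096 --max-batch-size 32
+```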
+
+#### Performance and Memory Parameters
+- `PAD_SEQUENCE_TO_MULTIPLE_OF` determines sizes of input length buckets. Since warmup creates several graphs for each bucket, it's important to adjust that value proportionally to input sequence length. Otherwise, some out of memory issues can be observed.
+- `ENABLE_HPU_GRAPH` enables HPU graphs usage, which is crucial for performance results. Recommended value to keep is `true`.
+
+## Reference
+
+This section contains reference information about the Gaudi backend.
+
+### Supported Models
+
+Text Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.
+
+**Large Language Models (LLMs)**
+- [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+- [Llama2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
+- [Llama3-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
+- [Llama3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
+- [LLama3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
+- [LLama3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
+- [CodeLlama-13B](https://huggingface.co/codellama/CodeLlama-13b-hf)
+- [Opt-125m](https://huggingface.co/facebook/opt-125m)
+- [OpenAI-gpt2](https://huggingface.co/openai-community/gpt2)
+- [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+- [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
+- [Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B-Instruct)
+- [Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
+- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)
+- [Gemma-7b](https://huggingface.co/google/gemma-7b-it)
+- [Starcoder2-3b](https://huggingface.co/bigcode/starcoder2-3b)
+- [Starcoder2-15b](https://huggingface.co/bigcode/starcoder2-15b)
+- [Starcoder](https://huggingface.co/bigcode/starcoder)
+- [falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct)
+- [Falcon-180B](https://huggingface.co/tiiuae/falcon-180B-chat)
+- [GPT-2](https://huggingface.co/openai-community/gpt2)
+- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)
+
+**Vision-Language Models (VLMs)**
+- [LLaVA-v1.6-Mistral-7B](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
+- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
+- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b)
+- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)
+- [Idefics 2.5](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)
+- [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
+- [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)
+
+On a best-effort basis, we also support models with different parameter counts that use the same model architectures as those listed above; note that these variants have not been tested. For example, the Gaudi backend supports `meta-llama/Llama-3.2-1B`, since its architecture is the standard Llama 3 architecture. If you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).
+
+### Environment Variables
+
+The following table contains the environment variables that can be used to configure the Gaudi backend:
+
+| Name | Value(s) | Default | Description | Usage |
+|-----------------------------| :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |
+| ENABLE_HPU_GRAPH | True/False | True | Enable hpu graph or not | add -e in docker run command |
+| LIMIT_HPU_GRAPH | True/False | True | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths (e.g. 300/212) | add -e in docker run command |
+| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PAD_SEQUENCE_TO_MULTIPLE_OF | integer | 128 | For prefill operation, sequences will be padded to a multiple of provided value. | add -e in docker run command |
+| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
+| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
+| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold beyond which requests are considered overdue and handled with priority. Shorter requests are prioritized otherwise. | add -e in docker run command |
+| USE_FLASH_ATTENTION | True/False | True | Whether to enable Habana Flash Attention, provided that the model supports it. Please refer to https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa | add -e in docker run command |
+| FLASH_ATTENTION_RECOMPUTE | True/False | True | Whether to enable Habana Flash Attention in recompute mode on first token generation. | add -e in docker run command |
+
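+For instance, a sketch that disables warmup while iterating on a configuration and turns off the Flash Attention recompute mode (whether either setting helps depends on your model and workload; `$model`, `$volume` and `$hf_token` are reused from the examples above):
+
+```bash
+docker run --runtime=habana --cap-add=sys_nice --ipc=host \
+    -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+    -e WARMUP_ENABLED=False \
+    -e FLASH_ATTENTION_RECOMPUTE=False \
+    ghcr.io/huggingface/text-generation-inference:3.2.3-gaudi \
+    --model-id $model
+```
+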
+## Contributing
+
+Contributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md).
+
+**Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder.
+
+### Building the Docker Image from Source
+
+To build the Docker image from source:
+
+```bash
+make -C backends/gaudi image
+```
+
+This builds the image and saves it as `tgi-gaudi`. You can then run TGI-Gaudi with this image:
+
+```bash
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
+volume=$PWD/data
+hf_token=YOUR_ACCESS_TOKEN
+
+docker run --runtime=habana --ipc=host --cap-add=sys_nice \
+ -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \
+ tgi-gaudi \
+ --model-id $model
+```
+
+For more details, see the [README of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/README.md) and the [Makefile of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/Makefile).
diff --git a/docs/source/backends/llamacpp.md b/docs/source/backends/llamacpp.md
new file mode 100644
index 000000000..5cf0edf0c
--- /dev/null
+++ b/docs/source/backends/llamacpp.md
@@ -0,0 +1,144 @@
+# Llamacpp Backend
+
+The llamacpp backend facilitates the deployment of large language models
+(LLMs) by integrating [llama.cpp][llama.cpp], an advanced inference engine
+optimized for both CPU and GPU computation. This backend is a component
+of Hugging Face’s **Text Generation Inference (TGI)** suite,
+specifically designed to streamline the deployment of LLMs in production
+environments.
+
+## Key Capabilities
+
+- Full compatibility with GGUF format and all quantization formats
+ (GGUF-related constraints may be mitigated dynamically by on-the-fly
+ generation in future updates)
+- Optimized inference on CPU and GPU architectures
+- Containerized deployment, eliminating dependency complexity
+- Seamless interoperability with the Hugging Face ecosystem
+
+## Model Compatibility
+
+This backend leverages models formatted in **GGUF**, providing an
+optimized balance between computational efficiency and model accuracy.
+You will find the best models on [Hugging Face][GGUF].
+
+## Build Docker image
+
+For optimal performance, the Docker image is compiled with native CPU
+instructions by default. As a result, it is strongly recommended to run
+the container on the same host architecture used during the build
+process. Efforts are ongoing to improve portability across different
+systems while preserving high computational efficiency.
+
+To build the Docker image, use the following command:
+
+```bash
+docker build \
+ -t tgi-llamacpp \
+ https://github.com/huggingface/text-generation-inference.git \
+ -f Dockerfile_llamacpp
+```
+
+### Build parameters
+
+| Parameter (with --build-arg) | Description |
+| ----------------------------------------- | -------------------------------- |
+| `llamacpp_version=bXXXX` | Specific version of llama.cpp |
+| `llamacpp_cuda=ON` | Enables CUDA acceleration |
+| `llamacpp_native=OFF` | Disable automatic CPU detection |
+| `llamacpp_cpu_arm_arch=ARCH[+FEATURE]...` | Specific ARM CPU and features |
+| `cuda_arch=ARCH` | Defines target CUDA architecture |
+
+For example, to target Graviton4 when building on another ARM
+architecture:
+
+```bash
+docker build \
+ -t tgi-llamacpp \
+ --build-arg llamacpp_native=OFF \
+ --build-arg llamacpp_cpu_arm_arch=armv9-a+i8mm \
+ https://github.com/huggingface/text-generation-inference.git \
+ -f Dockerfile_llamacpp
+```
+
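+Similarly, to enable CUDA acceleration, a sketch using the `llamacpp_cuda` and `cuda_arch` build arguments from the table above (the `86` architecture value is an assumed example; set it to match your GPU):
+
+```bash
+docker build \
+    -t tgi-llamacpp \
+    --build-arg llamacpp_cuda=ON \
+    --build-arg cuda_arch=86 \
+    https://github.com/huggingface/text-generation-inference.git \
+    -f Dockerfile_llamacpp
+```
+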
+## Run Docker image
+
+### CPU-based inference
+
+```bash
+docker run \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --model-id "Qwen/Qwen2.5-3B-Instruct"
+```
+
+### GPU-Accelerated inference
+
+```bash
+docker run \
+ --gpus all \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --n-gpu-layers 99 \
+ --model-id "Qwen/Qwen2.5-3B-Instruct"
+```
+
+## Using a custom GGUF
+
+GGUF files are optional as they will be automatically generated at
+startup if not already present in the `models` directory. However, if
+the default GGUF generation is not suitable for your use case, you can
+provide your own GGUF file with `--model-gguf`, for example:
+
+```bash
+docker run \
+ -p 3000:3000 \
+ -e "HF_TOKEN=$HF_TOKEN" \
+ -v "$HOME/models:/app/models" \
+ tgi-llamacpp \
+ --model-id "Qwen/Qwen2.5-3B-Instruct" \
+ --model-gguf "models/qwen2.5-3b-instruct-q4_0.gguf"
+```
+
+Note that `--model-id` is still required.
+
+## Advanced parameters
+
+A full listing of configurable parameters is available via `--help`:
+
+```bash
+docker run tgi-llamacpp --help
+```
+
+The table below summarizes key options:
+
+| Parameter | Description |
+|-------------------------------------|------------------------------------------------------------------------|
+| `--n-threads` | Number of threads to use for generation |
+| `--n-threads-batch` | Number of threads to use for batch processing |
+| `--n-gpu-layers` | Number of layers to store in VRAM |
+| `--split-mode` | Split the model across multiple GPUs |
+| `--defrag-threshold` | Defragment the KV cache if holes/size > threshold |
+| `--numa` | Enable NUMA optimizations |
+| `--disable-mmap` | Disable memory mapping for the model |
+| `--use-mlock` | Use memory locking to prevent swapping |
+| `--disable-offload-kqv` | Disable offloading of KQV operations to the GPU |
+| `--disable-flash-attention` | Disable flash attention |
+| `--type-k` | Data type used for K cache |
+| `--type-v` | Data type used for V cache |
+| `--validation-workers` | Number of tokenizer workers used for payload validation and truncation |
+| `--max-concurrent-requests` | Maximum number of concurrent requests |
+| `--max-input-tokens` | Maximum number of input tokens per request |
+| `--max-total-tokens` | Maximum number of total tokens (input + output) per request |
+| `--max-batch-total-tokens` | Maximum number of tokens in a batch |
+| `--max-physical-batch-total-tokens` | Maximum number of tokens in a physical batch |
+| `--max-batch-size` | Maximum number of requests per batch |
+
+---
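+As an illustration, a CPU-oriented run that pins thread counts and request limits could combine several of these options (a sketch; the values are placeholders to adapt to your hardware):
+
+```bash
+docker run \
+    -p 3000:3000 \
+    -e "HF_TOKEN=$HF_TOKEN" \
+    -v "$HOME/models:/app/models" \
+    tgi-llamacpp \
+    --model-id "Qwen/Qwen2.5-3B-Instruct" \
+    --n-threads 16 \
+    --n-threads-batch 16 \
+    --max-input-tokens 2048 \
+    --max-total-tokens 4096 \
+    --max-batch-size 4
+```
+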
+[llama.cpp]: https://github.com/ggerganov/llama.cpp
+[GGUF]: https://huggingface.co/models?library=gguf&sort=trending
diff --git a/docs/source/backends/neuron.md b/docs/source/backends/neuron.md
new file mode 100644
index 000000000..c8e3876e1
--- /dev/null
+++ b/docs/source/backends/neuron.md
@@ -0,0 +1,179 @@
+# Neuron backend for AWS Trainium and Inferentia
+
+The Neuron backend allows the deployment of TGI on the AWS Trainium and Inferentia family of chips.
+
+The following hardware targets are supported:
+- Trainium 1,
+- Inferentia 2.
+
+## Features
+
+The basic TGI features are supported:
+
+- continuous batching,
+- token streaming,
+- greedy search and multinomial sampling using [transformers](https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation).
+
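+Since the Neuron backend exposes the standard TGI HTTP API, these features are exercised through the usual routes once the service is running (see the deployment options below). A minimal sketch of a sampling request, assuming the container's port 80 was mapped to 8080:
+
+```bash
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32,"do_sample":true,"temperature":0.7}}' \
+    -H 'Content-Type: application/json'
+```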
+
+## Deploy the service from the Hugging Face hub
+
+The simplest way to deploy the NeuronX TGI service for a specific model is to follow the
+deployment instructions in the model card:
+
+- click on the "Deploy" button on the right,
+- select your deployment service ("Inference Endpoints" and "SageMaker" are supported),
+- select "AWS Trainum & Inferentia",
+- follow the instructions.
+
+
+## Deploy the service on a dedicated host
+
+The service is launched simply by running the text-generation-inference container with two sets of parameters:
+
+```
+docker run