Merge branch 'huggingface:main' into main

Commit 6111e9ecd5 by icyboy™, 2024-07-22 15:18:57 +08:00, committed by GitHub.
100 changed files with 8624 additions and 1447 deletions


@@ -30,6 +30,10 @@ jobs:
         id: install-router
         run: cargo install --path router/
+      - uses: actions/setup-node@v4
+        with:
+          node-version: 22
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
@@ -37,4 +41,5 @@ jobs:
       - name: Check that documentation is up-to-date
         run: |
+          npm install -g swagger-cli
           python update_doc.py --check


@@ -27,8 +27,8 @@ jobs:
     concurrency:
       group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    # TODO see with @Glegendre to get CPU runner here instead
-    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
+    runs-on:
+      group: aws-r7i-8xlarge-priv
     permissions:
       contents: write
       packages: write
@@ -49,7 +49,7 @@ jobs:
             export dockerfile="Dockerfile"
             export label_extension=""
             export docker_devices=""
-            export runs_on="nvidia-gpu"
+            export runs_on="aws-g5-12xlarge"
             ;;
           rocm)
             export dockerfile="Dockerfile_amd"
@@ -79,9 +79,15 @@ jobs:
         uses: docker/setup-buildx-action@v3
         with:
           install: true
-          config-inline: |
+          buildkitd-config-inline: |
             [registry."docker.io"]
-              mirrors = ["registry.github-runners.huggingface.tech"]
+              mirrors = ["registry-us-east-1-mirror.prod.aws.ci.huggingface.tech"]
+      - name: Login to internal Container Registry
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.REGISTRY_USERNAME }}
+          password: ${{ secrets.REGISTRY_PASSWORD }}
+          registry: registry.internal.huggingface.tech
       - name: Login to GitHub Container Registry
         if: github.event_name != 'pull_request'
         uses: docker/login-action@v3
@@ -103,7 +109,8 @@ jobs:
         uses: docker/metadata-action@v5
         with:
           images: |
-            registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
+            registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inference
           tags: |
             type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
           # If main, release or tag
@@ -115,7 +122,8 @@ jobs:
           flavor: |
             latest=auto
           images: |
-            registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference
+            registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference
+            registry.internal.huggingface.tech/api-inference/community/text-generation-inferenceca
             ghcr.io/huggingface/text-generation-inference
             db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
           tags: |
@@ -136,12 +144,12 @@ jobs:
            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
-          cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
+          cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+          cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
       - name: Final
         id: final
         run: |
-          echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
+          echo "docker_image=registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
           echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
           echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
           echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT"
@@ -150,7 +158,8 @@ jobs:
       group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     needs: build-and-push
-    runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
+    runs-on:
+      group: ${{ needs.build-and-push.outputs.runs_on }}
     if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
     env:
       PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}


@@ -15,7 +15,8 @@ jobs:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
+    runs-on:
+      group: aws-g5-12xlarge
     env:
       DOCKER_VOLUME: /cache
     steps:

Cargo.lock

@@ -801,6 +801,27 @@ dependencies = [
  "typenum",
 ]

+[[package]]
+name = "csv"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe"
+dependencies = [
+ "csv-core",
+ "itoa",
+ "ryu",
+ "serde",
+]
+
+[[package]]
+name = "csv-core"
+version = "0.1.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "ctrlc"
 version = "3.4.4"
@@ -1935,17 +1956,6 @@ version = "2.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"

-[[package]]
-name = "metrics"
-version = "0.21.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
-dependencies = [
- "ahash",
- "metrics-macros",
- "portable-atomic",
-]
-
 [[package]]
 name = "metrics"
 version = "0.23.0"
@@ -1969,7 +1979,7 @@ dependencies = [
  "hyper-util",
  "indexmap 2.2.6",
  "ipnet",
- "metrics 0.23.0",
+ "metrics",
  "metrics-util",
  "quanta",
  "thiserror",
@@ -1977,17 +1987,6 @@ dependencies = [
  "tracing",
 ]

-[[package]]
-name = "metrics-macros"
-version = "0.7.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.68",
-]
-
 [[package]]
 name = "metrics-util"
 version = "0.17.0"
@@ -1997,7 +1996,7 @@ dependencies = [
  "crossbeam-epoch",
  "crossbeam-utils",
  "hashbrown 0.14.5",
- "metrics 0.23.0",
+ "metrics",
  "num_cpus",
  "quanta",
  "sketches-ddsketch",
@@ -3424,9 +3423,9 @@ dependencies = [
 [[package]]
 name = "serde_json"
-version = "1.0.118"
+version = "1.0.120"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4"
+checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
 dependencies = [
  "itoa",
  "ryu",
@@ -3672,15 +3671,16 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
 [[package]]
 name = "sysinfo"
-version = "0.30.12"
+version = "0.30.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae"
+checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3"
 dependencies = [
  "cfg-if",
  "core-foundation-sys",
  "libc",
  "ntapi",
  "once_cell",
+ "rayon",
  "windows",
 ]
@@ -3762,7 +3762,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "2.1.1-dev0"
+version = "2.1.2-dev0"
 dependencies = [
  "average",
  "clap",
@@ -3783,7 +3783,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "2.1.1-dev0"
+version = "2.1.2-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -3801,7 +3801,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "2.1.1-dev0"
+version = "2.1.2-dev0"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3820,13 +3820,14 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "2.1.1-dev0"
+version = "2.1.2-dev0"
 dependencies = [
  "async-stream",
  "axum 0.7.5",
  "axum-tracing-opentelemetry",
  "base64 0.22.1",
  "clap",
+ "csv",
  "futures",
  "futures-util",
  "hf-hub",
@@ -3834,7 +3835,7 @@ dependencies = [
  "init-tracing-opentelemetry",
  "itertools 0.10.5",
  "jsonschema",
- "metrics 0.21.1",
+ "metrics",
  "metrics-exporter-prometheus",
  "minijinja",
  "minijinja-contrib",
@@ -3848,6 +3849,7 @@ dependencies = [
  "reqwest",
  "serde",
  "serde_json",
+ "sysinfo",
  "text-generation-client",
  "thiserror",
  "tokenizers",
@@ -3859,6 +3861,7 @@ dependencies = [
  "tracing-subscriber",
  "utoipa",
  "utoipa-swagger-ui",
+ "uuid",
  "vergen",
 ]
@@ -4530,9 +4533,25 @@ dependencies = [
 [[package]]
 name = "uuid"
-version = "1.9.1"
+version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439"
+checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314"
+dependencies = [
+ "getrandom",
+ "rand",
+ "uuid-macro-internal",
+]
+
+[[package]]
+name = "uuid-macro-internal"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.68",
+]

 [[package]]
 name = "v_frame"


@@ -40,7 +40,9 @@ RUN cargo build --profile release-opt
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
 FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install

+# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
 ARG PYTORCH_VERSION=2.3.0
+
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
@@ -159,6 +161,17 @@ COPY server/custom_kernels/ .
 # Build specific version of transformers
 RUN python setup.py build

+# Build FBGEMM CUDA kernels
+FROM kernel-builder AS fbgemm-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-fbgemm Makefile
+COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch
+COPY server/fix_torch90a.sh fix_torch90a.sh
+
+RUN make build-fbgemm
+
 # Build vllm CUDA kernels
 FROM kernel-builder AS vllm-builder
@@ -223,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
 # Copy build artifacts from marlin kernels builder
 COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# Copy builds artifacts from vllm builder
+# Copy build artifacts from fbgemm builder
+COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
+
+# Copy build artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
@@ -241,7 +254,10 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
+    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
+    pip install nvidia-nccl-cu12==2.22.3
+
+ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2

 # Deps before the binaries
 # The binaries change on every build given we burn the SHA into them
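
The last hunk pins NCCL: PyTorch 2.3.0 bundles an older `nvidia-nccl-cu12`, so the image installs 2.22.3 on top and forces it in via `LD_PRELOAD`. A minimal sanity check, assuming you are inside the built image (not part of the commit):

```bash
# torch reports the NCCL version actually loaded at runtime, so with the
# LD_PRELOAD above this should print (2, 22, 3), not the bundled version.
python -c "import torch; print(torch.cuda.nccl.version())"

# Confirm the preloaded library exists at the path referenced by LD_PRELOAD.
ls -l /opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
```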


@@ -21,14 +21,15 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
 ## Table of contents

 - [Get Started](#get-started)
-  - [API Documentation](#api-documentation)
+  - [Docker](#docker)
+  - [API documentation](#api-documentation)
   - [Using a private or gated model](#using-a-private-or-gated-model)
-  - [A note on Shared Memory](#a-note-on-shared-memory-shm)
+  - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
   - [Distributed Tracing](#distributed-tracing)
-  - [Local Install](#local-install)
-  - [CUDA Kernels](#cuda-kernels)
+- [Architecture](#architecture)
+- [Local install](#local-install)
 - [Optimized architectures](#optimized-architectures)
-  - [Run Mistral](#run-a-model)
+- [Run locally](#run-locally)
   - [Run](#run)
   - [Quantization](#quantization)
 - [Develop](#develop)


@@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel):
 class ChoiceDelta(BaseModel):
     role: str
     content: Optional[str] = None
-    tool_calls: Optional[ChoiceDeltaToolCall]
+    tool_calls: Optional[ChoiceDeltaToolCall] = None


 class Choice(BaseModel):
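
The `= None` default above is the whole fix: under pydantic v2 semantics, `Optional[...]` only widens the type to allow `None`; it does not make the field optional, so streaming chunks that omit `tool_calls` would previously fail validation. A minimal sketch (hypothetical field types) of the difference:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class Required(BaseModel):
    tool_calls: Optional[int]  # allows None, but must still be provided


class TrulyOptional(BaseModel):
    tool_calls: Optional[int] = None  # may be omitted entirely


TrulyOptional()  # ok: tool_calls defaults to None
try:
    Required()  # raises: tool_calls is a required field
except ValidationError as err:
    print(err)
```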


@@ -492,12 +492,12 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/Completion"
+                "$ref": "#/components/schemas/CompletionFinal"
               }
             },
             "text/event-stream": {
               "schema": {
-                "$ref": "#/components/schemas/CompletionCompleteChunk"
+                "$ref": "#/components/schemas/Chunk"
               }
             }
           }
@@ -809,7 +809,6 @@
       "ChatRequest": {
         "type": "object",
         "required": [
-          "model",
           "messages"
         ],
         "properties": {
@@ -854,7 +853,8 @@
           "model": {
             "type": "string",
             "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
-            "example": "mistralai/Mistral-7B-Instruct-v0.2"
+            "example": "mistralai/Mistral-7B-Instruct-v0.2",
+            "nullable": true
           },
           "n": {
             "type": "integer",
@@ -909,7 +909,7 @@
           "tool_choice": {
             "allOf": [
               {
-                "$ref": "#/components/schemas/ToolType"
+                "$ref": "#/components/schemas/ToolChoice"
               }
             ],
             "nullable": true
@@ -1116,7 +1116,6 @@
       "CompletionRequest": {
         "type": "object",
         "required": [
-          "model",
           "prompt"
         ],
         "properties": {
@@ -1138,7 +1137,8 @@
           "model": {
             "type": "string",
             "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
-            "example": "mistralai/Mistral-7B-Instruct-v0.2"
+            "example": "mistralai/Mistral-7B-Instruct-v0.2",
+            "nullable": true
           },
           "prompt": {
             "$ref": "#/components/schemas/Prompt"
@@ -1324,6 +1324,17 @@
           }
         }
       },
+      "FunctionName": {
+        "type": "object",
+        "required": [
+          "name"
+        ],
+        "properties": {
+          "name": {
+            "type": "string"
+          }
+        }
+      },
       "GenerateParameters": {
         "type": "object",
         "properties": {
@@ -1708,6 +1719,72 @@
           }
         }
       },
+      "MessageChunk": {
+        "oneOf": [
+          {
+            "type": "object",
+            "required": [
+              "text",
+              "type"
+            ],
+            "properties": {
+              "text": {
+                "type": "string"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "text"
+                ]
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "image_url",
+              "type"
+            ],
+            "properties": {
+              "image_url": {
+                "$ref": "#/components/schemas/Url"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "image_url"
+                ]
+              }
+            }
+          }
+        ],
+        "discriminator": {
+          "propertyName": "type"
+        }
+      },
+      "MessageContent": {
+        "oneOf": [
+          {
+            "type": "string"
+          },
+          {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/MessageChunk"
+            }
+          }
+        ]
+      },
+      "OutputMessage": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/TextMessage"
+          },
+          {
+            "$ref": "#/components/schemas/ToolCallMessage"
+          }
+        ]
+      },
       "PrefillToken": {
         "type": "object",
         "required": [
@@ -1834,6 +1911,23 @@
           }
         }
       },
+      "TextMessage": {
+        "type": "object",
+        "required": [
+          "role",
+          "content"
+        ],
+        "properties": {
+          "content": {
+            "type": "string",
+            "example": "My name is David and I"
+          },
+          "role": {
+            "type": "string",
+            "example": "user"
+          }
+        }
+      },
       "Token": {
         "type": "object",
         "required": [
@@ -1906,6 +2000,49 @@
           }
         }
       },
+      "ToolCallDelta": {
+        "type": "object",
+        "required": [
+          "role",
+          "tool_calls"
+        ],
+        "properties": {
+          "role": {
+            "type": "string",
+            "example": "assistant"
+          },
+          "tool_calls": {
+            "$ref": "#/components/schemas/DeltaToolCall"
+          }
+        }
+      },
+      "ToolCallMessage": {
+        "type": "object",
+        "required": [
+          "role",
+          "tool_calls"
+        ],
+        "properties": {
+          "role": {
+            "type": "string",
+            "example": "assistant"
+          },
+          "tool_calls": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/ToolCall"
+            }
+          }
+        }
+      },
+      "ToolChoice": {
+        "allOf": [
+          {
+            "$ref": "#/components/schemas/ToolType"
+          }
+        ],
+        "nullable": true
+      },
       "ToolType": {
         "oneOf": [
           {
@@ -1926,9 +2063,25 @@
                 "$ref": "#/components/schemas/FunctionName"
               }
             }
-          }
+          },
+          {
+            "type": "object",
+            "default": null,
+            "nullable": true
+          }
         ]
       },
+      "Url": {
+        "type": "object",
+        "required": [
+          "url"
+        ],
+        "properties": {
+          "url": {
+            "type": "string"
+          }
+        }
+      },
       "Usage": {
         "type": "object",
         "required": [


@@ -11,6 +11,8 @@
       title: Using TGI with Intel Gaudi
     - local: installation_inferentia
       title: Using TGI with AWS Inferentia
+    - local: installation_intel
+      title: Using TGI with Intel GPUs
     - local: installation
       title: Installation from source
     - local: supported_models
@@ -19,6 +21,8 @@
       title: Messages API
     - local: architecture
       title: Internal Architecture
+    - local: usage_statistics
+      title: Usage Statistics
   title: Getting started
 - sections:
   - local: basic_tutorials/consuming_tgi


@@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin
 - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
 - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
+- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
 - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
 - A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
 - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).


@@ -424,6 +424,22 @@ Options:
           [env: LORA_ADAPTERS=]

 ```
+## DISABLE_USAGE_STATS
+```shell
+      --disable-usage-stats
+          Disable sending of all usage statistics
+
+          [env: DISABLE_USAGE_STATS=]
+
+```
+## DISABLE_CRASH_REPORTS
+```shell
+      --disable-crash-reports
+          Disable sending of crash reports, but allow anonymous usage statistics
+
+          [env: DISABLE_CRASH_REPORTS=]
+
+```
 ## HELP
 ```shell

@@ -0,0 +1,19 @@
# Using TGI with Intel GPUs
TGI-optimized models are supported on the Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html) and [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html); the recommended usage is through Docker.
On a server powered by Intel GPUs, TGI can be launched with the following command:
```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:latest-intel \
--model-id $model --cuda-graphs 0
```
The launched TGI server can then be queried from clients; be sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
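
For a quick smoke test, something like the following should work from the same host (a sketch: with `--net host` the server listens on the image's default port, 80):

```bash
curl 127.0.0.1:80/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}'
```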


@@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
 ### Supported hardware

-TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
+TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.

 ## Consuming TGI


@@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 ## Supported Models

+- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)

@@ -0,0 +1,73 @@
# Collection of Usage Statistics
Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data helps us understand what causes failures and guides improvements to TGI; it is collected transparently, and any sensitive information is omitted.
Data is sent twice: once on server startup and once when the server stops. Usage statistics are only enabled when TGI is running in Docker, to avoid collecting data when TGI runs directly on the host machine.
## What data is collected
The code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs).
As of release 2.1.2, this is an example of the data collected:
- From the TGI configuration:
```json
{
"event_type": "start",
"disable_grammar_support": false,
"max_batch_prefill_tokens": 4096,
"max_batch_size": null,
"max_batch_total_tokens": null,
"max_best_of": 2,
"max_client_batch_size": 4,
"max_concurrent_requests": 128,
"max_input_tokens": 1024,
"max_stop_sequences": 4,
"max_top_n_tokens": 5,
"max_total_tokens": 2048,
"max_waiting_tokens": 20,
"messages_api_enabled": false,
"model_config": {
"model_type": "Bloom"
},
"revision": null,
"tokenizer_class": "BloomTokenizerFast",
"validation_workers": 2,
"waiting_served_ratio": 1.2,
"docker_label": "latest",
"git_sha": "cfc118704880453d29bcbe4fbbd91dda501cf5fe",
"nvidia_env": {
"name": "NVIDIA A10G",
"pci_bus_id": "00000000:00:1E.0",
"driver_version": "535.183.01",
"pstate": "P8",
"pcie_link_gen_max": "4",
"pcie_link_gen_current": "1",
"temperature_gpu": "31",
"utilization_gpu": "0 %",
"utilization_memory": "0 %",
"memory_total": "23028 MiB",
"memory_free": "22515 MiB",
"memory_used": "0 MiB",
"reset_status_reset_required": "No",
"reset_status_drain_and_reset_recommended": "No",
"compute_cap": "8.6",
"ecc_errors_corrected_volatile_total": "0",
"mig_mode_current": "[N/A]",
"power_draw_instant": "10.86 W",
"power_limit": "300.00 W"
},
"system_env": {
"cpu_count": 16,
"cpu_type": "AMD EPYC 7R32",
"total_memory": 66681196544,
"architecture": "x86_64",
"platform": "linux-unix-x86_64"
}
}
```
## How to opt-out
You can opt out by passing the `--disable-usage-stats` flag to the text-generation-launcher command; this disables all usage statistics. You can also pass `--disable-crash-reports`, which disables sending specific crash reports while still allowing anonymous usage statistics.
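
For example (a sketch; the model id is illustrative):

```bash
# Disable all usage statistics:
text-generation-launcher --model-id teknium/OpenHermes-2.5-Mistral-7B \
    --disable-usage-stats

# Keep anonymous usage statistics, but do not send crash reports:
text-generation-launcher --model-id teknium/OpenHermes-2.5-Mistral-7B \
    --disable-crash-reports
```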


@@ -333,6 +333,8 @@ def launcher(event_loop):
         max_input_length: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
+        lora_adapters: Optional[List[str]] = None,
+        cuda_graphs: Optional[List[int]] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -379,6 +381,14 @@ def launcher(event_loop):
         if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))
+        if lora_adapters:
+            args.append("--lora-adapters")
+            args.append(",".join(lora_adapters))
+        if cuda_graphs:
+            args.append("--cuda-graphs")
+            args.append(",".join(map(str, cuda_graphs)))
+
+        print(" ".join(args), file=sys.stderr)

         env["LOG_LEVEL"] = "info,text_generation_router=debug"
@@ -418,6 +428,8 @@ def launcher(event_loop):
         max_input_length: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
+        lora_adapters: Optional[List[str]] = None,
+        cuda_graphs: Optional[List[int]] = None,
     ):
         port = random.randint(8000, 10_000)
@@ -447,6 +459,12 @@ def launcher(event_loop):
         if max_total_tokens:
             args.append("--max-total-tokens")
             args.append(str(max_total_tokens))
+        if lora_adapters:
+            args.append("--lora-adapters")
+            args.append(",".join(lora_adapters))
+        if cuda_graphs:
+            args.append("--cuda-graphs")
+            args.append(",".join(map(str, cuda_graphs)))

         client = docker.from_env()
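
A sketch of how a test module might use the two new parameters (model and adapter ids are illustrative; the `launcher` fixture is used as a context manager, as elsewhere in this conftest):

```python
import pytest


@pytest.fixture(scope="module")
def lora_llama_handle(launcher):
    # Preload one LoRA adapter; cuda_graphs=[0] maps to `--cuda-graphs 0`,
    # which disables graph capture (as in the Intel example above).
    with launcher(
        "meta-llama/Meta-Llama-3-8B",  # illustrative model id
        lora_adapters=["some-org/some-lora-adapter"],  # illustrative adapter id
        cuda_graphs=[0],
    ) as handle:
        yield handle
```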

@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.84375,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.34375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.8359375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.0859375,
"special": false,
"text": " is"
},
{
"id": 254,
"logprob": -1.5390625,
"special": false,
"text": " the"
},
{
"id": 1022,
"logprob": -1.1875,
"special": false,
"text": " first"
},
{
"id": 3458,
"logprob": -0.35546875,
"special": false,
"text": " step"
},
{
"id": 279,
"logprob": -0.8828125,
"special": false,
"text": " in"
},
{
"id": 254,
"logprob": -0.71484375,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is the first step in the"
}

@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 2143,
"logprob": -1.828125,
"special": false,
"text": " sent"
},
{
"id": 10081,
"logprob": -0.36914062,
"special": false,
"text": " successfully"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 185,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 1380,
"logprob": -0.38671875,
"special": false,
"text": "We"
},
{
"id": 543,
"logprob": -0.12695312,
"special": false,
"text": " will"
},
{
"id": 752,
"logprob": -0.20117188,
"special": false,
"text": " get"
},
{
"id": 279,
"logprob": 0.0,
"special": false,
"text": " in"
},
{
"id": 5402,
"logprob": 0.0,
"special": false,
"text": " touch"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " with"
}
],
"top_tokens": null
},
"generated_text": "Test request sent successfully.\nWe will get in touch with"
}

@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
}
]

@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 2323,
"logprob": -9.421875,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7900391,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3554688,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0039062,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.4489746,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8100586,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013015747,
"special": false,
"text": " year"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
}

@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 2323,
"logprob": -9.421875,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 25,
"logprob": -0.8535156,
"special": false,
"text": ":"
},
{
"id": 2209,
"logprob": -2.4804688,
"special": false,
"text": " Is"
},
{
"id": 279,
"logprob": -0.7167969,
"special": false,
"text": " the"
},
{
"id": 734,
"logprob": -2.625,
"special": false,
"text": " function"
},
{
"id": 330,
"logprob": -0.35131836,
"special": false,
"text": " \""
},
{
"id": 4110,
"logprob": -2.4101562,
"special": false,
"text": "Create"
},
{
"id": 264,
"logprob": -0.23181152,
"special": false,
"text": " a"
},
{
"id": 502,
"logprob": -0.25512695,
"special": false,
"text": " new"
},
{
"id": 1052,
"logprob": -1.2792969,
"special": false,
"text": " file"
},
{
"id": 1,
"logprob": -1.2529297,
"special": false,
"text": "\""
}
],
"top_tokens": null
},
"generated_text": "Test request: Is the function \"Create a new file\""
}

@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 2323,
"logprob": -9.421875,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"special": false,
"text": " year"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 2323,
"logprob": -9.421875,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"special": false,
"text": " year"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 2323,
"logprob": -9.421875,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"special": false,
"text": " year"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 2323,
"logprob": -9.421875,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
},
{
"id": 220,
"logprob": -3.6308594,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"special": false,
"text": "201"
},
{
"id": 24,
"logprob": -1.3535156,
"special": false,
"text": "9"
},
{
"id": 12,
"logprob": -2.0058594,
"special": false,
"text": "-"
},
{
"id": 2366,
"logprob": -0.45410156,
"special": false,
"text": "202"
},
{
"id": 15,
"logprob": -0.037109375,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"special": false,
"text": " school"
},
{
"id": 1060,
"logprob": -0.013053894,
"special": false,
"text": " year"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
}
]

@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}

@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -0.6645508,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -2.2324219,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 6088,
"logprob": -1.6074219,
"special": false,
"text": " parse"
},
{
"id": 1243,
"logprob": -1.6298828,
"special": false,
"text": " test"
},
{
"id": 1206,
"logprob": -0.72558594,
"special": false,
"text": " case"
},
{
"id": 1024,
"logprob": -0.40429688,
"special": false,
"text": " name"
},
{
"id": 515,
"logprob": 0.0,
"special": false,
"text": " from"
},
{
"id": 525,
"logprob": -1.2519531,
"special": false,
"text": " '"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not parse test case name from '"
}

@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}
]

@@ -0,0 +1,251 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 40,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.27416992,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.17016602,
"special": false,
"text": "\n"
},
{
"id": 28737,
"logprob": -2.7109375,
"special": false,
"text": "I"
},
{
"id": 28809,
"logprob": -1.5,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.34204102,
"special": false,
"text": "m"
},
{
"id": 459,
"logprob": -1.6914062,
"special": false,
"text": " not"
},
{
"id": 1864,
"logprob": -0.69140625,
"special": false,
"text": " sure"
},
{
"id": 513,
"logprob": -1.6171875,
"special": false,
"text": " if"
},
{
"id": 315,
"logprob": -1.3837891,
"special": false,
"text": " I"
},
{
"id": 541,
"logprob": -1.2226562,
"special": false,
"text": " can"
},
{
"id": 1567,
"logprob": -1.8652344,
"special": false,
"text": " come"
},
{
"id": 582,
"logprob": -0.0070228577,
"special": false,
"text": " up"
},
{
"id": 395,
"logprob": -0.0054092407,
"special": false,
"text": " with"
},
{
"id": 28705,
"logprob": -0.62597656,
"special": false,
"text": " "
},
{
"id": 28770,
"logprob": -0.0035572052,
"special": false,
"text": "3"
},
{
"id": 4842,
"logprob": -0.93603516,
"special": false,
"text": " unique"
},
{
"id": 3085,
"logprob": -0.028411865,
"special": false,
"text": " words"
},
{
"id": 369,
"logprob": -1.0400391,
"special": false,
"text": " that"
},
{
"id": 6685,
"logprob": -0.09710693,
"special": false,
"text": " describe"
},
{
"id": 528,
"logprob": -0.066467285,
"special": false,
"text": " me"
},
{
"id": 28725,
"logprob": -1.0722656,
"special": false,
"text": ","
},
{
"id": 562,
"logprob": -0.33422852,
"special": false,
"text": " but"
},
{
"id": 315,
"logprob": -0.5136719,
"special": false,
"text": " I"
},
{
"id": 28809,
"logprob": -0.8989258,
"special": false,
"text": ""
},
{
"id": 584,
"logprob": -0.2076416,
"special": false,
"text": "ll"
},
{
"id": 1464,
"logprob": -0.8808594,
"special": false,
"text": " try"
},
{
"id": 28723,
"logprob": -0.88427734,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.91064453,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.08105469,
"special": false,
"text": "\n"
},
{
"id": 28740,
"logprob": -1.8486328,
"special": false,
"text": "1"
},
{
"id": 28723,
"logprob": -0.111572266,
"special": false,
"text": "."
},
{
"id": 23626,
"logprob": -3.15625,
"special": false,
"text": " Creative"
},
{
"id": 13,
"logprob": -0.9194336,
"special": false,
"text": "\n"
},
{
"id": 28750,
"logprob": -0.24841309,
"special": false,
"text": "2"
},
{
"id": 28723,
"logprob": -9.393692e-05,
"special": false,
"text": "."
},
{
"id": 6785,
"logprob": -3.1386719,
"special": false,
"text": " Fun"
},
{
"id": 1780,
"logprob": -0.53564453,
"special": false,
"text": "ny"
},
{
"id": 13,
"logprob": -0.09033203,
"special": false,
"text": "\n"
},
{
"id": 28770,
"logprob": -0.00466156,
"special": false,
"text": "3"
},
{
"id": 28723,
"logprob": -0.00016450882,
"special": false,
"text": "."
}
]
},
"generated_text": "\n\nIm not sure if I can come up with 3 unique words that describe me, but Ill try.\n\n1. Creative\n2. Funny\n3."
}

View File

@ -0,0 +1,53 @@
{
"details": {
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 1,
"logprob": -0.49658203,
"special": true,
"text": "<s>"
},
{
"id": 28705,
"logprob": -0.0016384125,
"special": false,
"text": " "
},
{
"id": 1,
"logprob": -1.4931641,
"special": true,
"text": "<s>"
},
{
"id": 28705,
"logprob": -0.00075769424,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -0.25024414,
"special": false,
"text": "1"
},
{
"id": 28740,
"logprob": -0.2631836,
"special": false,
"text": "1"
},
{
"id": 2,
"logprob": -0.0003285408,
"special": true,
"text": "</s>"
}
]
},
"generated_text": " 11"
}

View File

@ -0,0 +1,251 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 40,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.0488281,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.0800781,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -2.1152344,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -1.6748047,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -0.097229004,
"special": false,
"text": "1"
},
{
"id": 28723,
"logprob": -0.16467285,
"special": false,
"text": "."
},
{
"id": 7615,
"logprob": -2.2246094,
"special": false,
"text": " News"
},
{
"id": 13,
"logprob": -1.0488281,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.69189453,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.013343811,
"special": false,
"text": " "
},
{
"id": 28750,
"logprob": -0.011230469,
"special": false,
"text": "2"
},
{
"id": 28723,
"logprob": -0.00096845627,
"special": false,
"text": "."
},
{
"id": 21095,
"logprob": -2.5605469,
"special": false,
"text": " Blog"
},
{
"id": 13,
"logprob": -0.19458008,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.031280518,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.0030708313,
"special": false,
"text": " "
},
{
"id": 28770,
"logprob": -0.0029277802,
"special": false,
"text": "3"
},
{
"id": 28723,
"logprob": -0.0012350082,
"special": false,
"text": "."
},
{
"id": 20108,
"logprob": -2.1582031,
"special": false,
"text": " Article"
},
{
"id": 13,
"logprob": -0.05810547,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.35083008,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.034332275,
"special": false,
"text": " "
},
{
"id": 28781,
"logprob": -0.009666443,
"special": false,
"text": "4"
},
{
"id": 28723,
"logprob": -0.0013113022,
"special": false,
"text": "."
},
{
"id": 8349,
"logprob": -2.6191406,
"special": false,
"text": " Review"
},
{
"id": 13,
"logprob": -0.04031372,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.45239258,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.045410156,
"special": false,
"text": " "
},
{
"id": 28782,
"logprob": -0.0041236877,
"special": false,
"text": "5"
},
{
"id": 28723,
"logprob": -0.0010223389,
"special": false,
"text": "."
},
{
"id": 5299,
"logprob": -2.8066406,
"special": false,
"text": " Other"
},
{
"id": 13,
"logprob": -0.12054443,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.44580078,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.4921875,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.3574219,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.0039062,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.5859375,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.43481445,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.2783203,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.20410156,
"special": false,
"text": "\n"
}
]
},
"generated_text": "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n"
}

View File

@ -0,0 +1,251 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 40,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.31347656,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.27441406,
"special": false,
"text": "\n"
},
{
"id": 28737,
"logprob": -2.2285156,
"special": false,
"text": "I"
},
{
"id": 28809,
"logprob": -1.4677734,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.31762695,
"special": false,
"text": "m"
},
{
"id": 264,
"logprob": -1.6865234,
"special": false,
"text": " a"
},
{
"id": 1215,
"logprob": -3.2695312,
"special": false,
"text": " very"
},
{
"id": 20640,
"logprob": -3.1230469,
"special": false,
"text": " passionate"
},
{
"id": 1338,
"logprob": -0.48339844,
"special": false,
"text": " person"
},
{
"id": 28723,
"logprob": -0.9970703,
"special": false,
"text": "."
},
{
"id": 315,
"logprob": -0.5498047,
"special": false,
"text": " I"
},
{
"id": 28809,
"logprob": -1.1923828,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.080444336,
"special": false,
"text": "m"
},
{
"id": 1215,
"logprob": -1.8271484,
"special": false,
"text": " very"
},
{
"id": 12215,
"logprob": -2.8847656,
"special": false,
"text": " driven"
},
{
"id": 28723,
"logprob": -1.0927734,
"special": false,
"text": "."
},
{
"id": 315,
"logprob": -0.4584961,
"special": false,
"text": " I"
},
{
"id": 28809,
"logprob": -0.5019531,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.030715942,
"special": false,
"text": "m"
},
{
"id": 1215,
"logprob": -0.96972656,
"special": false,
"text": " very"
},
{
"id": 7798,
"logprob": -2.8847656,
"special": false,
"text": " determined"
},
{
"id": 28723,
"logprob": -0.27319336,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.56396484,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.011016846,
"special": false,
"text": "\n"
},
{
"id": 3195,
"logprob": -0.7163086,
"special": false,
"text": "What"
},
{
"id": 349,
"logprob": -1.1611328,
"special": false,
"text": " is"
},
{
"id": 574,
"logprob": -0.515625,
"special": false,
"text": " your"
},
{
"id": 6656,
"logprob": -1.0253906,
"special": false,
"text": " favorite"
},
{
"id": 1970,
"logprob": -2.1738281,
"special": false,
"text": " thing"
},
{
"id": 684,
"logprob": -0.48364258,
"special": false,
"text": " about"
},
{
"id": 1250,
"logprob": -1.8876953,
"special": false,
"text": " being"
},
{
"id": 264,
"logprob": -0.41967773,
"special": false,
"text": " a"
},
{
"id": 8626,
"logprob": -2.9160156,
"special": false,
"text": " teacher"
},
{
"id": 28804,
"logprob": -0.11920166,
"special": false,
"text": "?"
},
{
"id": 13,
"logprob": -0.023727417,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.010848999,
"special": false,
"text": "\n"
},
{
"id": 28737,
"logprob": -1.0566406,
"special": false,
"text": "I"
},
{
"id": 2016,
"logprob": -0.7163086,
"special": false,
"text": " love"
},
{
"id": 272,
"logprob": -1.9169922,
"special": false,
"text": " the"
},
{
"id": 1639,
"logprob": -2.03125,
"special": false,
"text": " fact"
}
]
},
"generated_text": "\n\nIm a very passionate person. Im very driven. Im very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
}

View File

@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
chunk = [c.replace("data:", "") for c in chunk] chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings # remove empty strings
chunk = [c for c in chunk if c] chunk = [c for c in chunk if c]
# remove the completion marker chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json # parse json
chunk = [json.loads(c) for c in chunk] chunk = [json.loads(c) for c in chunk]
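The streaming route now terminates the event stream with a "data: [DONE]" sentinel, which the test above strips before JSON-decoding. A minimal client-side sketch of the same handling in Rust, assuming serde_json and a raw SSE payload string; the function name is illustrative, not part of the router's API:

use serde_json::Value;

/// Split a raw SSE payload into JSON chunks, dropping empty keep-alives
/// and the final "[DONE]" completion marker.
fn parse_sse_payload(payload: &str) -> Vec<Value> {
    payload
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim)
        .filter(|data| !data.is_empty() && *data != "[DONE]")
        .filter_map(|data| serde_json::from_str(data).ok())
        .collect()
}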

View File

@ -0,0 +1,63 @@
import pytest


@pytest.fixture(scope="module")
def flash_deepseek_v2_handle(launcher):
    with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_deepseek_v2(flash_deepseek_v2_handle):
    await flash_deepseek_v2_handle.health(300)
    return flash_deepseek_v2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot):
    response = await flash_deepseek_v2.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot):
    response = await flash_deepseek_v2.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_load(
    flash_deepseek_v2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_deepseek_v2, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot

View File

@ -0,0 +1,62 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_handle(launcher):
with launcher("meta-llama/Meta-Llama-3-8B", num_shard=2, quantize="fp8") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_fp8(flash_llama_fp8_handle):
await flash_llama_fp8_handle.health(300)
return flash_llama_fp8_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
response = await flash_llama_fp8.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
response = await flash_llama_fp8.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot):
responses = await generate_load(
flash_llama_fp8, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -0,0 +1,66 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_marlin24_handle(launcher):
with launcher(
"nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin24_handle):
await flash_llama_marlin24_handle.health(300)
return flash_llama_marlin24_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_load(
flash_llama_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -0,0 +1,134 @@
import pytest
import requests


@pytest.fixture(scope="module")
def lora_mistral_handle(launcher):
    with launcher(
        "mistralai/Mistral-7B-v0.1",
        lora_adapters=[
            "predibase/dbpedia",
            "predibase/customer_support",
        ],
        cuda_graphs=[0],
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def lora_mistral(lora_mistral_handle):
    await lora_mistral_handle.health(300)
    return lora_mistral_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral(lora_mistral, response_snapshot):
    response = await lora_mistral.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
    assert response.details.generated_tokens == 10


classification_prompt = """You are given the title and the body of an article below. Please determine the type of the article.\n### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\n\n### Article Type:"""


@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_without_adapter(lora_mistral, response_snapshot):
    response = requests.post(
        f"{lora_mistral.base_url}/generate",
        headers=lora_mistral.headers,
        json={
            "inputs": classification_prompt,
            "parameters": {
                "max_new_tokens": 40,
                "details": True,
            },
        },
    )

    assert response.status_code == 200
    data = response.json()
    assert (
        data["generated_text"]
        == "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n"
    )
    assert data == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot):
    response = requests.post(
        f"{lora_mistral.base_url}/generate",
        headers=lora_mistral.headers,
        json={
            "inputs": classification_prompt,
            "parameters": {
                "max_new_tokens": 40,
                "adapter_id": "predibase/dbpedia",
                "details": True,
            },
        },
    )

    assert response.status_code == 200
    data = response.json()
    assert data["generated_text"] == " 11"
    assert data == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_with_customer_support_adapter(
    lora_mistral, response_snapshot
):
    print(lora_mistral.base_url)
    print(lora_mistral.headers)
    response = requests.post(
        f"{lora_mistral.base_url}/generate",
        headers=lora_mistral.headers,
        json={
            "inputs": "What are 3 unique words that describe you?",
            "parameters": {
                "max_new_tokens": 40,
                "adapter_id": "predibase/customer_support",
                "details": True,
            },
        },
    )

    assert response.status_code == 200
    data = response.json()
    assert (
        data["generated_text"]
        == "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3."
    )
    assert data == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_without_customer_support_adapter(
    lora_mistral, response_snapshot
):
    response = requests.post(
        f"{lora_mistral.base_url}/generate",
        headers=lora_mistral.headers,
        json={
            "inputs": "What are 3 unique words that describe you?",
            "parameters": {
                "max_new_tokens": 40,
                "details": True,
            },
        },
    )

    assert response.status_code == 200
    data = response.json()
    assert (
        data["generated_text"]
        == "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
    )
    assert data == response_snapshot
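For reference, the per-request adapter selection exercised above can be driven from any HTTP client. A hedged Rust sketch using reqwest's blocking API and serde_json; the endpoint URL is an assumption, not taken from the tests:

use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "inputs": "What are 3 unique words that describe you?",
        "parameters": {
            // Route this request through an adapter loaded via --lora-adapters.
            "adapter_id": "predibase/customer_support",
            "max_new_tokens": 40,
            "details": true,
        }
    });
    let resp: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://localhost:3000/generate") // assumed local TGI endpoint
        .json(&body)
        .send()?
        .json()?;
    println!("{}", resp["generated_text"]);
    Ok(())
}

Omitting "adapter_id" falls back to the base model, which is what the two "without adapter" tests check.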

View File

@ -457,6 +457,14 @@ struct Args {
/// startup that will be available to callers via the `adapter_id` field in a request. /// startup that will be available to callers via the `adapter_id` field in a request.
#[clap(long, env)] #[clap(long, env)]
lora_adapters: Option<String>, lora_adapters: Option<String>,
/// Disable sending of all usage statistics
#[clap(default_value = "false", long, env)]
disable_usage_stats: bool,
/// Disable sending of crash reports, but allow anonymous usage statistics
#[clap(default_value = "false", long, env)]
disable_crash_reports: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@ -1201,6 +1209,14 @@ fn spawn_webserver(
args.model_id, args.model_id,
]; ];
// Pass usage stats flags to router
if args.disable_usage_stats {
router_args.push("--disable-usage-stats".to_string());
}
if args.disable_crash_reports {
router_args.push("--disable-crash-reports".to_string());
}
// Grammar support // Grammar support
if args.disable_grammar_support { if args.disable_grammar_support {
router_args.push("--disable-grammar-support".to_string()); router_args.push("--disable-grammar-support".to_string());

View File

@ -24,7 +24,7 @@ futures = "0.3.28"
hf-hub = { workspace = true } hf-hub = { workspace = true }
itertools = "0.10" itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1" metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] } metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
@ -52,6 +52,10 @@ regex = "1.10.3"
once_cell = "1.19.0" once_cell = "1.19.0"
image = "0.25.1" image = "0.25.1"
base64 = { workspace = true } base64 = { workspace = true }
sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
csv = "1.3.0"
[build-dependencies] [build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
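The metrics bump from 0.21 to 0.23 drives the mechanical changes in the router files below: the value-taking macros (increment_counter!, gauge!(name, value), histogram!(name, value)) were replaced by macros that return a handle exposing increment, set, and record methods. A minimal sketch of the new style, assuming a recorder such as metrics-exporter-prometheus is already installed:

use std::time::Instant;

fn record_request_metrics(start: Instant, queue_len: usize) {
    // Counters: the macro now returns a Counter handle to increment.
    metrics::counter!("tgi_request_count").increment(1);
    // Labels are still attached inside the macro invocation.
    metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
    // Gauges gain explicit set/increment/decrement methods on the handle.
    metrics::gauge!("tgi_queue_size").set(queue_len as f64);
    // Histograms record observations through the handle.
    metrics::histogram!("tgi_request_duration").record(start.elapsed().as_secs_f64());
}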

View File

@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
}; };
use crate::{ use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
@ -91,14 +91,14 @@ impl Infer {
.limit_concurrent_requests .limit_concurrent_requests
.try_acquire_owned() .try_acquire_owned()
.map_err(|err| { .map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
err err
})?; })?;
// Validate request // Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| { let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
err err
})?; })?;
@ -140,7 +140,7 @@ impl Infer {
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt) .apply(messages, grammar_with_prompt)
.map_err(|e| { .map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template"); metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}"); tracing::error!("{e}");
e e
}) })
@ -214,7 +214,7 @@ impl Infer {
}) })
} else { } else {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
Err(err) Err(err)
} }
@ -332,29 +332,37 @@ impl ChatTemplate {
pub struct ToolGrammar {} pub struct ToolGrammar {}
impl ToolGrammar { impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply( pub fn apply(
tools: Option<Vec<Tool>>, tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>, tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> { ) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { // if no tools are provided, we return None
// let tool_prompt = tool_prompt.unwrap_or_default(); let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => { ToolType::FunctionName(name) => {
vec![req_tools vec![Self::find_tool_by_name(&tools, &name)?]
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
} }
ToolType::Function { function } => { ToolType::Function { function } => {
let tool = req_tools vec![Self::find_tool_by_name(&tools, &function.name)?]
.iter()
.find(|tool| tool.function.name == function.name)
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
.clone();
vec![tool]
} }
ToolType::OneOf => req_tools.to_owned(), ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
}; };
// adds the error notification function for LLM feedback if required // adds the error notification function for LLM feedback if required
@ -448,10 +456,7 @@ impl ToolGrammar {
}, },
}; };
return Ok(Some(tools)); Ok(Some(tools))
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
} }
} }

View File

@ -111,7 +111,7 @@ async fn queue_task(
match cmd { match cmd {
QueueCommand::Append(entry, span) => { QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry)); span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0); metrics::gauge!("tgi_queue_size").increment(1.0);
} }
QueueCommand::NextBatch { QueueCommand::NextBatch {
min_size, min_size,
@ -124,7 +124,7 @@ async fn queue_task(
let next_batch = let next_batch =
state.next_batch(min_size, max_size, prefill_token_budget, token_budget); state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap(); response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64); metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}), }),
} }
} }
@ -226,7 +226,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request // Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client) // was dropped by the client)
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry"); tracing::debug!("Dropping entry");
continue; continue;
} }
@ -336,7 +336,7 @@ impl State {
// Increment batch id // Increment batch id
self.next_batch_id += 1; self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64); metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span)) Some((batch_entries, batch, next_batch_span))
} }

View File

@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size; let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens; let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch]; let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64); metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens { let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try // If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
{ {
// Tracking metrics // Tracking metrics
if min_size.is_some() { if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else { } else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
} }
entries.iter_mut().for_each(|(_, entry)| { entries.iter_mut().for_each(|(_, entry)| {
@ -219,8 +221,8 @@ pub(crate) async fn batching_task(
.await; .await;
waiting_tokens += 1; waiting_tokens += 1;
} }
metrics::gauge!("tgi_batch_current_size", 0.0); metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0); metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
} }
} }
} }
@ -234,7 +236,7 @@ async fn prefill(
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_id = batch.id; let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await { match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
@ -248,11 +250,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped // Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await; let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); .record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); .record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
@ -261,7 +267,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst); generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await; let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None None
} }
} }
@ -276,7 +282,7 @@ async fn decode(
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect(); let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await { match client.decode(batches).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
@ -291,13 +297,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await; let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat { if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
} }
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); .record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); .record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
@ -307,7 +318,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await; let _ = client.clear_cache(Some(id)).await;
} }
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None None
} }
} }
@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true) // request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| { let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error."); tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err err
}).unwrap_or(true); }).unwrap_or(true);
if stopped { if stopped {
@ -381,7 +392,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> { ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected // Return directly if the channel is disconnected
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true); return Ok(true);
} }
@ -407,7 +418,7 @@ fn send_responses(
// Create last Token // Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len(); let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_ let mut iterator = tokens_
.ids .ids
.into_iter() .into_iter()
@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry // Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string()); let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation"); metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -126,7 +126,7 @@ async fn queue_task(
match cmd { match cmd {
QueueCommand::Append(entry, span) => { QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry)); span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0); metrics::gauge!("tgi_queue_size").increment(1.0);
} }
QueueCommand::NextBatch { QueueCommand::NextBatch {
min_size, min_size,
@ -141,7 +141,7 @@ async fn queue_task(
.instrument(span) .instrument(span)
.await; .await;
response_sender.send(next_batch).unwrap(); response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64); metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
} }
} }
} }
@ -248,7 +248,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request // Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client) // was dropped by the client)
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry"); tracing::debug!("Dropping entry");
continue; continue;
} }
@ -399,7 +399,7 @@ impl State {
// Increment batch id // Increment batch id
self.next_batch_id += 1; self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64); metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span)) Some((batch_entries, batch, next_batch_span))
} }

View File

@ -154,8 +154,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size; let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens; let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch]; let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64); metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens { let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try // If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -176,9 +176,11 @@ pub(crate) async fn batching_task(
{ {
// Tracking metrics // Tracking metrics
if min_size.is_some() { if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else { } else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
} }
entries.iter_mut().for_each(|(_, entry)| { entries.iter_mut().for_each(|(_, entry)| {
@ -225,8 +227,8 @@ pub(crate) async fn batching_task(
.await; .await;
waiting_tokens += 1; waiting_tokens += 1;
} }
metrics::gauge!("tgi_batch_current_size", 0.0); metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0); metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
} }
} }
} }
@ -240,7 +242,7 @@ async fn prefill(
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_id = batch.id; let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await { match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
@ -254,11 +256,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped // Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await; let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); .record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); .record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
@ -267,7 +273,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst); generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await; let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None None
} }
} }
@ -282,7 +288,7 @@ async fn decode(
) -> Option<CachedBatch> { ) -> Option<CachedBatch> {
let start_time = Instant::now(); let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect(); let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await { match client.decode(batches).await {
Ok((generations, next_batch, timings)) => { Ok((generations, next_batch, timings)) => {
@ -297,13 +303,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await; let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat { if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
} }
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); .record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); .record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
@ -313,7 +324,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await; let _ = client.clear_cache(Some(id)).await;
} }
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None None
} }
} }
@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true) // request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| { let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error."); tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err err
}).unwrap_or(true); }).unwrap_or(true);
if stopped { if stopped {
@ -387,7 +398,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> { ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected // Return directly if the channel is disconnected
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true); return Ok(true);
} }
@ -413,7 +424,7 @@ fn send_responses(
// Create last Token // Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len(); let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_ let mut iterator = tokens_
.ids .ids
.into_iter() .into_iter()
@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry // Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string()); let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation"); metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -7,6 +7,8 @@ mod validation;
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
mod kserve; mod kserve;
pub mod usage_stats;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::warn; use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
@ -40,13 +42,13 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>, pub pipeline_tag: Option<String>,
} }
#[derive(Debug, Clone, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatTemplate { pub struct ChatTemplate {
name: String, name: String,
template: String, template: String,
} }
#[derive(Debug, Clone, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)] #[serde(untagged)]
pub enum ChatTemplateVersions { pub enum ChatTemplateVersions {
Single(String), Single(String),
@ -55,7 +57,7 @@ pub enum ChatTemplateVersions {
use std::path::Path; use std::path::Path;
#[derive(Debug, Clone, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HubTokenizerConfig { pub struct HubTokenizerConfig {
pub chat_template: Option<ChatTemplateVersions>, pub chat_template: Option<ChatTemplateVersions>,
pub completion_template: Option<String>, pub completion_template: Option<String>,
@ -384,7 +386,7 @@ pub struct CompletionRequest {
/// UNUSED /// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, pub model: Option<String>,
/// The prompt to generate completions for. /// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
@ -731,7 +733,7 @@ impl ChatCompletionChunk {
pub(crate) struct ChatRequest { pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, pub model: Option<String>,
/// A list of messages comprising the conversation so far. /// A list of messages comprising the conversation so far.
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
@ -824,7 +826,7 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
pub tool_choice: Option<ToolType>, pub tool_choice: ToolChoice,
/// Response format constraints for the generation. /// Response format constraints for the generation.
/// ///
@ -846,34 +848,34 @@ pub enum ToolType {
OneOf, OneOf,
FunctionName(String), FunctionName(String),
Function { function: FunctionName }, Function { function: FunctionName },
NoTool,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct FunctionName { pub struct FunctionName {
pub name: String, pub name: String,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
#[serde(from = "ToolTypeDeserializer")] #[serde(from = "ToolTypeDeserializer")]
pub struct ToolChoice(pub Option<ToolType>); pub struct ToolChoice(pub Option<ToolType>);
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(untagged)] #[serde(untagged)]
enum ToolTypeDeserializer { enum ToolTypeDeserializer {
None(Option<String>), String(String),
Some(ToolType), ToolType(ToolType),
} }
impl From<ToolTypeDeserializer> for ToolChoice { impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self { fn from(value: ToolTypeDeserializer) -> Self {
match value { match value {
ToolTypeDeserializer::None(opt) => match opt.as_deref() { ToolTypeDeserializer::String(s) => match s.as_str() {
Some("none") => ToolChoice(None), "none" => ToolChoice(Some(ToolType::NoTool)),
Some("auto") => ToolChoice(Some(ToolType::OneOf)), "auto" => ToolChoice(Some(ToolType::OneOf)),
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), _ => ToolChoice(Some(ToolType::FunctionName(s))),
None => ToolChoice(Some(ToolType::OneOf)),
}, },
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
} }
} }
} }
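Read together with the From impl, bare strings in tool_choice now get OpenAI-compatible semantics: "none" maps to NoTool, "auto" to OneOf, and any other string to FunctionName; an absent field defaults to ToolChoice(None), which ToolGrammar::apply resolves to OneOf. Object forms are handled by ToolType's own Deserialize. A hedged sketch of the string mapping as a serde test, assuming it lives alongside the definitions above:

#[cfg(test)]
mod tool_choice_tests {
    use super::*;

    #[test]
    fn bare_strings_map_to_tool_types() {
        // "none" disables tool use entirely.
        let none: ToolChoice = serde_json::from_str(r#""none""#).unwrap();
        assert_eq!(none, ToolChoice(Some(ToolType::NoTool)));

        // "auto" lets the model pick among the provided tools.
        let auto: ToolChoice = serde_json::from_str(r#""auto""#).unwrap();
        assert_eq!(auto, ToolChoice(Some(ToolType::OneOf)));

        // Any other bare string is treated as a function name.
        let named: ToolChoice = serde_json::from_str(r#""get_weather""#).unwrap();
        assert_eq!(
            named,
            ToolChoice(Some(ToolType::FunctionName("get_weather".to_string())))
        );
    }
}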

View File

@ -14,6 +14,7 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use text_generation_router::config::Config; use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{ use text_generation_router::{
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
}; };
@ -87,6 +88,10 @@ struct Args {
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
#[clap(long, env, default_value_t)]
disable_usage_stats: bool,
#[clap(long, env, default_value_t)]
disable_crash_reports: bool,
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -128,6 +133,8 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
command, command,
} = args; } = args;
@ -210,7 +217,11 @@ async fn main() -> Result<(), RouterError> {
} }
let api = if use_api { let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default(); let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults"); tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache) Type::Cache(cache)
} else { } else {
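Untangled from the diff above, the offline branch now honors HUGGINGFACE_HUB_CACHE before falling back to the default cache location. A hedged sketch of the resolution order, assuming hf-hub's Cache::new(PathBuf) constructor:

use hf_hub::Cache;

/// Resolve the hub cache for HF_HUB_OFFLINE=1: prefer an explicit
/// HUGGINGFACE_HUB_CACHE directory, else fall back to the default location.
fn resolve_offline_cache() -> Cache {
    std::env::var("HUGGINGFACE_HUB_CACHE")
        .map(|dir| Cache::new(dir.into()))
        .unwrap_or_else(|_| Cache::default())
}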
@ -320,6 +331,7 @@ async fn main() -> Result<(), RouterError> {
tracing::warn!("Could not find tokenizer config locally and no API specified"); tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default() HubTokenizerConfig::default()
}); });
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| { let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok(); let mut tokenizer = Tokenizer::from_file(filename).ok();
@ -374,8 +386,47 @@ async fn main() -> Result<(), RouterError> {
} }
}; };
// Only send usage stats when TGI runs in a container and usage stats are enabled
let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = if !disable_usage_stats && is_container {
let reduced_args = usage_stats::Args::new(
config.clone(),
tokenizer_class,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
);
Some(usage_stats::UserAgent::new(reduced_args))
} else {
None
};
if let Some(ref ua) = user_agent {
let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move {
start_event.send().await;
});
};
// Run server // Run server
server::run( let result = server::run(
master_shard_uds_path, master_shard_uds_path,
model_info, model_info,
compat_return_full_text, compat_return_full_text,
@ -406,9 +457,42 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size, max_client_batch_size,
print_schema_command, print_schema_command,
) )
.await?; .await;
match result {
Ok(_) => {
if let Some(ref ua) = user_agent {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
};
Ok(()) Ok(())
} }
Err(e) => {
if let Some(ref ua) = user_agent {
if !disable_crash_reports {
let error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some(e.to_string()),
);
error_event.send().await;
} else {
let unknown_error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some("unknown_error".to_string()),
);
unknown_error_event.send().await;
}
};
Err(RouterError::WebServer(e))
}
}
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector /// - otlp_endpoint is an optional URL to an Open Telemetry collector

View File

@@ -11,10 +11,11 @@ use crate::kserve::{
};
use crate::validation::ValidationError;
use crate::{
-BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
-Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
-Usage, Validation,
+BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
+GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig,
+HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken,
+SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse,
+ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -23,7 +24,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
};
-use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -185,7 +186,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
-metrics::increment_counter!("tgi_request_count");
+metrics::counter!("tgi_request_count").increment(1);

// Do not log ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
@@ -301,25 +302,15 @@ pub(crate) async fn generate_internal(
);

// Metrics
-metrics::increment_counter!("tgi_request_success");
-metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
-metrics::histogram!(
-    "tgi_request_validation_duration",
-    validation_time.as_secs_f64()
-);
-metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
-metrics::histogram!(
-    "tgi_request_inference_duration",
-    inference_time.as_secs_f64()
-);
-metrics::histogram!(
-    "tgi_request_mean_time_per_token_duration",
-    time_per_token.as_secs_f64()
-);
-metrics::histogram!(
-    "tgi_request_generated_tokens",
-    response.generated_text.generated_tokens as f64
-);
+metrics::counter!("tgi_request_success").increment(1);
+metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
+metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
+metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
+metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
+metrics::histogram!("tgi_request_mean_time_per_token_duration")
+    .record(time_per_token.as_secs_f64());
+metrics::histogram!("tgi_request_generated_tokens")
+    .record(response.generated_text.generated_tokens as f64);

// Send response
let mut output_text = response.generated_text.text;
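These call sites track the `metrics` crate's 0.22+ API, in which the macros return a handle and values are pushed through `increment`/`record` rather than being passed as macro arguments. A minimal sketch of the new style (the wrapper function is hypothetical; the metric names are the ones used here, and without an installed recorder the calls are no-ops):

use std::time::Instant;

fn record_request_metrics(start: Instant) {
    // metrics >= 0.22: the macro returns a Counter/Histogram handle...
    metrics::counter!("tgi_request_count").increment(1);
    // ...and the value goes through a method instead of a macro argument.
    metrics::histogram!("tgi_request_duration").record(start.elapsed().as_secs_f64());
    // Label syntax is unchanged.
    metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
}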
@@ -399,7 +390,7 @@ async fn generate_stream_internal(
span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let start_time = Instant::now();
-metrics::increment_counter!("tgi_request_count");
+metrics::counter!("tgi_request_count").increment(1);

tracing::debug!("Input: {}", req.inputs);
@@ -427,12 +418,12 @@ async fn generate_stream_internal(
let best_of = req.parameters.best_of.unwrap_or(1);
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
@@ -500,13 +491,13 @@ async fn generate_stream_internal(
span.record("seed", format!("{:?}", generated_text.seed));

// Metrics
-metrics::increment_counter!("tgi_request_success");
-metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
-metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
-metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
-metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
-metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
-metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
+metrics::counter!("tgi_request_success").increment(1);
+metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
+metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
+metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
+metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
+metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
+metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);

// StreamResponse
end_reached = true;
@@ -553,7 +544,7 @@ async fn generate_stream_internal(
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
-metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
+metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
}
@@ -572,8 +563,8 @@ request_body = CompletionRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
-("application/json" = Completion),
-("text/event-stream" = CompletionCompleteChunk),
+("application/json" = CompletionFinal),
+("text/event-stream" = Chunk),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
@@ -604,9 +595,10 @@ async fn completions(
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
-metrics::increment_counter!("tgi_request_count");
+metrics::counter!("tgi_request_count").increment(1);

let CompletionRequest {
+model,
max_tokens,
seed,
stop,
@@ -625,7 +617,7 @@ async fn completions(
// if suffix is present throw an error
if req.suffix.is_some() {
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@@ -637,7 +629,7 @@ async fn completions(
}

if req.prompt.0.len() > info.max_client_batch_size {
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@@ -675,7 +667,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
-..Default::default()
+adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
})
.collect();
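The `..Default::default()` filler is replaced by an explicit `adapter_id`: the OpenAI-compatible `model` field now doubles as a LoRA adapter selector, where the reserved name "tgi" (or no model at all) means the base model. A small sketch of that mapping (function name and values are hypothetical):

fn adapter_id_from_model(model: Option<String>) -> Option<String> {
    // Anything other than the reserved base-model name selects an adapter.
    model.filter(|m| m != "tgi")
}

fn main() {
    assert_eq!(adapter_id_from_model(Some("tgi".into())), None);
    assert_eq!(adapter_id_from_model(Some("my-lora".into())), Some("my-lora".into()));
    assert_eq!(adapter_id_from_model(None), None);
}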
@@ -820,6 +812,10 @@ async fn completions(
}
};

+let stream = stream.chain(futures::stream::once(async {
+    Ok(Event::default().data("[DONE]"))
+}));
+
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
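The completions stream (and the chat stream below) now ends with the OpenAI-style `[DONE]` sentinel, produced by chaining a one-item stream after the body stream. A toy, runnable sketch of the pattern with strings standing in for serialized SSE events (assumes the `futures` and `tokio` crates):

use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // Hypothetical payloads standing in for serialized completion chunks.
    let chunks = stream::iter(vec!["chunk-0".to_string(), "chunk-1".to_string()]);
    // `chain` yields the sentinel only after the inner stream is exhausted.
    let with_done = chunks.chain(stream::once(async { "[DONE]".to_string() }));
    assert_eq!(
        with_done.collect::<Vec<_>>().await,
        vec!["chunk-0", "chunk-1", "[DONE]"]
    );
}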
@@ -1009,8 +1005,9 @@ async fn chat_completions(
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
-metrics::increment_counter!("tgi_request_count");
+metrics::counter!("tgi_request_count").increment(1);

let ChatRequest {
+model,
logprobs,
max_tokens,
messages,
@@ -1039,7 +1036,7 @@ async fn chat_completions(
// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@@ -1053,7 +1050,7 @@ async fn chat_completions(
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
@@ -1082,7 +1079,7 @@ async fn chat_completions(
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
@@ -1116,7 +1113,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
-..Default::default()
+adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};
@@ -1178,6 +1175,11 @@ async fn chat_completions(
span,
)
.await;

+let response_stream = response_stream.chain(futures::stream::once(async {
+    Ok(Event::default().data("[DONE]"))
+}));
+
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
@@ -1190,39 +1192,33 @@ async fn chat_completions(
.as_secs();

let (tool_calls, output) = if tool_grammar.is_some() {
-// gen_text should be valid json
-let gen_text_value: Value =
-    serde_json::from_str(&generation.generated_text).map_err(|e| {
-        (
-            StatusCode::UNPROCESSABLE_ENTITY,
-            Json(ErrorResponse {
-                error: e.to_string(),
-                error_type: "Input validation error".to_string(),
-            }),
-        )
-    })?;
+let gen_text_value: Value = serde_json::from_str(&generation.generated_text)
+    .map_err(|e| InferError::ToolError(e.to_string()))?;
+let function = gen_text_value.get("function").ok_or(InferError::ToolError(
+    "No function found in generated text".to_string(),
+))?;
+
+let name = function
+    .get("_name")
+    .and_then(Value::as_str)
+    .ok_or(InferError::ToolError(
+        "No _name found in generated text".to_string(),
+    ))?
+    .to_string();
+
+let mut arguments = function.clone();
+if let Value::Object(ref mut props) = arguments {
+    props.remove("_name");
+}

let tool_calls = vec![ToolCall {
    id: "0".to_string(),
    r#type: "function".to_string(),
    function: FunctionDefinition {
        description: None,
-        name: gen_text_value
-            .get("function")
-            .and_then(|f| f.get("_name"))
-            .and_then(|name| name.as_str())
-            .unwrap_or("default_function_name")
-            .to_string(),
-        // Serialize the JSON object obtained from "function" to an escaped JSON string
-        arguments: gen_text_value
-            .get("function")
-            .map(|f| {
-                let mut f_cloned = f.clone();
-                if let Value::Object(ref mut props) = f_cloned {
-                    props.remove("_name");
-                }
-                f_cloned
-            })
-            .unwrap_or_default(),
+        name,
+        arguments,
    },
}];
(Some(tool_calls), None)
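The rewrite replaces the old `unwrap_or(...)` fallbacks with explicit `InferError::ToolError`s and splits the grammar-constrained JSON into a function name plus an arguments object. The extraction itself boils down to this sketch (helper name hypothetical; requires `serde_json`):

use serde_json::{json, Value};

// Split output shaped like {"function": {"_name": ..., ...}} into the tool
// name and the remaining arguments object.
fn split_function(gen_text: &str) -> Option<(String, Value)> {
    let value: Value = serde_json::from_str(gen_text).ok()?;
    let function = value.get("function")?.clone();
    let name = function.get("_name")?.as_str()?.to_string();
    let mut arguments = function;
    if let Value::Object(ref mut props) = arguments {
        props.remove("_name");
    }
    Some((name, arguments))
}

fn main() {
    let gen = r#"{"function": {"_name": "get_weather", "city": "Paris"}}"#;
    let (name, args) = split_function(gen).unwrap();
    assert_eq!(name, "get_weather");
    assert_eq!(args, json!({"city": "Paris"}));
}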
@@ -1280,7 +1276,7 @@ async fn vertex_compatibility(
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
-metrics::increment_counter!("tgi_request_count");
+metrics::counter!("tgi_request_count").increment(1);

// check that there's at least one instance
if req.instances.is_empty() {
@@ -1454,6 +1450,14 @@ pub async fn run(
GrammarType,
ChatRequest,
Message,
+MessageContent,
+MessageChunk,
+Url,
+FunctionName,
+OutputMessage,
+TextMessage,
+ToolCallMessage,
+ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,
@@ -1488,6 +1492,7 @@ pub async fn run(
ToolCall,
Function,
FunctionDefinition,
+ToolChoice,
)
),
tags(
router/src/usage_stats.rs (new file)
@@ -0,0 +1,355 @@
use crate::config::Config;
use csv::ReaderBuilder;
use reqwest::header::HeaderMap;
use serde::Serialize;
use std::{
fs::File,
io::{self, BufRead},
path::Path,
process::Command,
time::Duration,
};
use uuid::Uuid;
const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi";
#[derive(Debug, Clone, Serialize)]
pub struct UserAgent {
pub uid: String,
pub args: Args,
pub env: Env,
}
impl UserAgent {
pub fn new(reduced_args: Args) -> Self {
Self {
uid: Uuid::new_v4().to_string(),
args: reduced_args,
env: Env::new(),
}
}
}
#[derive(Serialize, Debug)]
pub enum EventType {
Start,
Stop,
Error,
}
#[derive(Debug, Serialize)]
pub struct UsageStatsEvent {
user_agent: UserAgent,
event_type: EventType,
#[serde(skip_serializing_if = "Option::is_none")]
error_reason: Option<String>,
}
impl UsageStatsEvent {
pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option<String>) -> Self {
Self {
user_agent,
event_type,
error_reason,
}
}
pub async fn send(&self) {
let mut headers = HeaderMap::new();
headers.insert("Content-Type", "application/json".parse().unwrap());
let body = serde_json::to_string(&self).unwrap();
let client = reqwest::Client::new();
let _ = client
.post(TELEMETRY_URL)
.headers(headers)
.body(body)
.timeout(Duration::from_secs(5))
.send()
.await;
}
}
#[derive(Debug, Clone, Serialize)]
pub struct Args {
model_config: Option<Config>,
tokenizer_config: Option<String>,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
disable_usage_stats: bool,
disable_crash_reports: bool,
}
impl Args {
#[allow(clippy::too_many_arguments)]
pub fn new(
model_config: Option<Config>,
tokenizer_config: Option<String>,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
disable_usage_stats: bool,
disable_crash_reports: bool,
) -> Self {
Self {
model_config,
tokenizer_config,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
}
}
}
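Taken together with `UserAgent` and `UsageStatsEvent` above, emitting a Start event boils down to the following sketch (all argument values are hypothetical placeholders; `main.rs` passes the real CLI values and spawns the send so it cannot block startup):

#[tokio::main]
async fn main() {
    let args = Args::new(
        None, None, 128, 2, 4, 5, 1024, 2048, 1.2, 4096, None, 20, None, None,
        2, false, false, 4, false, false,
    );
    let user_agent = UserAgent::new(args);
    let start = UsageStatsEvent::new(user_agent, EventType::Start, None);
    // Fire and forget: send() deliberately ignores network errors.
    start.send().await;
}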
/// This is more or less a copy of the code from the `text-generation-launcher` crate to avoid a dependency
#[derive(Serialize, Debug, Clone)]
pub struct Env {
git_sha: &'static str,
docker_label: &'static str,
nvidia_info: Option<Vec<NvidiaSmiInfo>>,
xpu_info: Option<Vec<XpuSmiInfo>>,
system_env: SystemInfo,
}
#[derive(Debug, Serialize, Clone)]
struct NvidiaSmiInfo {
name: String,
pci_bus_id: String,
driver_version: String,
pstate: String,
pcie_link_gen_max: String,
pcie_link_gen_current: String,
temperature_gpu: String,
utilization_gpu: String,
utilization_memory: String,
memory_total: String,
memory_free: String,
memory_used: String,
reset_status_reset_required: String,
reset_status_drain_and_reset_recommended: String,
compute_cap: String,
ecc_errors_corrected_volatile_total: String,
mig_mode_current: String,
power_draw_instant: String,
power_limit: String,
}
impl NvidiaSmiInfo {
fn new() -> Option<Vec<NvidiaSmiInfo>> {
let output = Command::new("nvidia-smi")
.args([
"--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit",
"--format=csv"
])
.output()
.ok()?;
if !output.status.success() {
return None;
}
let stdout = String::from_utf8(output.stdout).ok()?;
let mut rdr = ReaderBuilder::new()
.has_headers(true)
.from_reader(stdout.as_bytes());
let mut infos = Vec::new();
for result in rdr.records() {
let record = result.ok()?;
infos.push(NvidiaSmiInfo {
name: record[0].to_string(),
pci_bus_id: record[1].to_string(),
driver_version: record[2].to_string(),
pstate: record[3].to_string(),
pcie_link_gen_max: record[4].to_string(),
pcie_link_gen_current: record[5].to_string(),
temperature_gpu: record[6].to_string(),
utilization_gpu: record[7].to_string(),
utilization_memory: record[8].to_string(),
memory_total: record[9].to_string(),
memory_free: record[10].to_string(),
memory_used: record[11].to_string(),
reset_status_reset_required: record[12].to_string(),
reset_status_drain_and_reset_recommended: record[13].to_string(),
compute_cap: record[14].to_string(),
ecc_errors_corrected_volatile_total: record[15].to_string(),
mig_mode_current: record[16].to_string(),
power_draw_instant: record[17].to_string(),
power_limit: record[18].to_string(),
});
}
Some(infos)
}
}
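For reference, a self-contained toy of the same `ReaderBuilder` pattern (the sample data is hypothetical; note that `nvidia-smi --format=csv` pads fields with spaces, which `Trim::All` strips):

use csv::ReaderBuilder;

fn main() {
    let sample = "name, driver_version\nNVIDIA A10G, 535.104.05\n";
    let mut rdr = ReaderBuilder::new()
        .has_headers(true)
        .trim(csv::Trim::All) // strip the space padding around each field
        .from_reader(sample.as_bytes());
    for record in rdr.records() {
        let record = record.unwrap();
        println!("gpu={} driver={}", &record[0], &record[1]);
    }
}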
#[derive(Debug, Serialize, Clone)]
struct XpuSmiInfo {
device_id: usize,
gpu_utilization: f32,
gpu_power: f32,
gpu_core_temperature: f32,
gpu_memory_bandwidth_utilization: f32,
}
impl XpuSmiInfo {
/// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format
fn new() -> Option<Vec<XpuSmiInfo>> {
let output = Command::new("xpu-smi")
.args([
"dump", "-d", "-1", "-m",
"0,1,3,17", // Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization
"-n", "1", "-j",
])
.output()
.ok()?;
if !output.status.success() {
return None;
}
let stdout = String::from_utf8(output.stdout).ok()?;
let mut infos = Vec::new();
let json_data: serde_json::Value = match serde_json::from_str(&stdout) {
Ok(data) => data,
Err(_) => return None,
};
if let Some(metrics_data) = json_data.as_array() {
for entry in metrics_data {
let device_id = entry["deviceId"].as_u64()? as usize;
let gpu_utilization = entry["metrics"][0].as_f64()? as f32;
let gpu_power = entry["metrics"][1].as_f64()? as f32;
let gpu_core_temperature = entry["metrics"][2].as_f64()? as f32;
let gpu_memory_bandwidth_utilization = entry["metrics"][3].as_f64()? as f32;
infos.push(XpuSmiInfo {
device_id,
gpu_utilization,
gpu_power,
gpu_core_temperature,
gpu_memory_bandwidth_utilization,
});
}
}
Some(infos)
}
}
#[derive(Serialize, Debug, Clone)]
pub struct SystemInfo {
cpu_count: usize,
cpu_type: String,
total_memory: u64,
architecture: String,
platform: String,
}
impl SystemInfo {
fn new() -> Self {
let mut system = sysinfo::System::new_all();
system.refresh_all();
let cpu_count = system.cpus().len();
let cpu_type = system.cpus()[0].brand().to_string();
let total_memory = system.total_memory();
let architecture = std::env::consts::ARCH.to_string();
let platform = format!(
"{}-{}-{}",
std::env::consts::OS,
std::env::consts::FAMILY,
std::env::consts::ARCH
);
Self {
cpu_count,
cpu_type,
total_memory,
architecture,
platform,
}
}
}
impl Default for Env {
fn default() -> Self {
Self::new()
}
}
impl Env {
pub fn new() -> Self {
Self {
system_env: SystemInfo::new(),
nvidia_info: NvidiaSmiInfo::new(),
xpu_info: XpuSmiInfo::new(),
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
}
}
}
pub fn is_container() -> io::Result<bool> {
let path = Path::new("/proc/self/cgroup");
let file = File::open(path)?;
let reader = io::BufReader::new(file);
for line in reader.lines() {
let line = line?;
// Check for common container runtimes
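// Example (hypothetical IDs) - cgroup v1: "12:pids:/docker/<container-id>",
// cgroup v2 (systemd driver): "0::/system.slice/docker-<container-id>.scope"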
if line.contains("/docker/")
|| line.contains("/docker-")
|| line.contains("/kubepods/")
|| line.contains("/kubepods-")
|| line.contains("containerd")
|| line.contains("crio")
|| line.contains("podman")
{
return Ok(true);
}
}
Ok(false)
}
router/src/validation.rs
@@ -157,7 +157,7 @@ impl Validation {
));
}

-metrics::histogram!("tgi_request_input_length", input_length as f64);
+metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens))
}
// Return inputs without validation
@@ -384,7 +384,7 @@ impl Validation {
ignore_eos_token: false,
};

-metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
Ok(ValidGenerateRequest {
inputs,
server/Makefile
@@ -5,6 +5,7 @@ include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
+include Makefile-fbgemm

unit-tests:
pytest -s -vv -m "not private" tests
@@ -21,13 +22,15 @@ gen-server:
install-server: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
-pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+pip install -e ".[accelerate, quantize, peft, outlines]"

install: install-cuda
echo "Installed server"

-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
+install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
+pip install -e ".[bnb]"
+pip install nvidia-nccl-cu12==2.22.3

install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
@@ -37,3 +40,4 @@ run-dev:
export-requirements:
poetry export -o requirements_cuda.txt --without-hashes
poetry export -o requirements_rocm.txt --without-hashes
+poetry export -o requirements_intel.txt --without-hashes
server/Makefile-fbgemm (new file)
@@ -0,0 +1,15 @@
fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca
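# Builds the pinned FBGEMM commit as the "genai" package variant for SM80 and
# SM90a only. fbgemm_remove_unused.patch comments out the unused embedding and
# sparse kernels, and fix_torch90a.sh makes torch < 2.4 accept the "9.0a" arch.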
build-fbgemm:
chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
cp fbgemm_remove_unused.patch fbgemm && \
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \
git submodule update --init --recursive && \
cd fbgemm_gpu && \
pip install -r requirements.txt && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build
install-fbgemm: build-fbgemm
cd fbgemm/fbgemm_gpu && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install
server/Makefile-vllm
@@ -1,14 +1,14 @@
-commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
+commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921

build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/Narsil/vllm.git vllm; \
fi
-cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build
+cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build

install-vllm-cuda: build-vllm-cuda
-cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e .
+cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .

build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
server/fbgemm_remove_unused.patch (new file)
@@ -0,0 +1,306 @@
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 2244ea6f..96265a48 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -94,14 +94,14 @@ endif()
# Build Experimental Modules
################################################################################
-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
- # TODO: Figure out NCCL/RCCL integration with ROCm
- add_subdirectory(experimental/example)
-endif()
-
-if(NOT FBGEMM_CPU_ONLY)
- add_subdirectory(experimental/gemm)
-endif()
+# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
+# # TODO: Figure out NCCL/RCCL integration with ROCm
+# add_subdirectory(experimental/example)
+# endif()
+
+# if(NOT FBGEMM_CPU_ONLY)
+# add_subdirectory(experimental/gemm)
+# endif()
if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM)
# CUTLASS currently doesn't build on ROCm and CK hasnt yet been added:
diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index c56773fe..0c0d349e 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources}
################################################################################
set(fbgemm_gpu_sources_static_cpu
- codegen/training/forward/embedding_forward_split_cpu.cpp
- codegen/inference/embedding_forward_quantized_host_cpu.cpp
- codegen/training/backward/embedding_backward_dense_host_cpu.cpp
- codegen/utils/embedding_bounds_check_host_cpu.cpp
- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
- src/input_combine_ops/input_combine_cpu.cpp
- src/layout_transform_ops/layout_transform_ops_cpu.cpp
+ # codegen/training/forward/embedding_forward_split_cpu.cpp
+ # codegen/inference/embedding_forward_quantized_host_cpu.cpp
+ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+ # codegen/utils/embedding_bounds_check_host_cpu.cpp
+ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
+ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
+ # src/input_combine_ops/input_combine_cpu.cpp
+ # src/layout_transform_ops/layout_transform_ops_cpu.cpp
src/quantize_ops/quantize_ops_cpu.cpp
src/quantize_ops/quantize_ops_meta.cpp
- src/sparse_ops/sparse_ops_cpu.cpp
- src/sparse_ops/sparse_ops_meta.cpp
- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
- src/split_embeddings_cache/linearize_cache_indices.cpp
- src/split_embeddings_cache/lfu_cache_populate_byte.cpp
- src/split_embeddings_cache/lru_cache_populate_byte.cpp
- src/split_embeddings_cache/lxu_cache.cpp
- src/split_embeddings_cache/split_embeddings_cache_ops.cpp
- codegen/training/index_select/batch_index_select_dim0_ops.cpp
- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
+ # src/sparse_ops/sparse_ops_cpu.cpp
+ # src/sparse_ops/sparse_ops_meta.cpp
+ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
+ # src/split_embeddings_cache/linearize_cache_indices.cpp
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp
+ # src/split_embeddings_cache/lru_cache_populate_byte.cpp
+ # src/split_embeddings_cache/lxu_cache.cpp
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+ # codegen/training/index_select/batch_index_select_dim0_ops.cpp
+ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
+)
if(NOT FBGEMM_CPU_ONLY)
list(APPEND fbgemm_gpu_sources_static_cpu
- codegen/inference/embedding_forward_quantized_host.cpp
- codegen/utils/embedding_bounds_check_host.cpp
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
- src/layout_transform_ops/layout_transform_ops_gpu.cpp
- src/memory_utils/memory_utils.cpp
- src/memory_utils/memory_utils_ops.cpp
- src/memory_utils/memory_utils_ops_cpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
+ # codegen/inference/embedding_forward_quantized_host.cpp
+ # codegen/utils/embedding_bounds_check_host.cpp
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
+ # src/layout_transform_ops/layout_transform_ops_gpu.cpp
+ # src/memory_utils/memory_utils.cpp
+ # src/memory_utils/memory_utils_ops.cpp
+ # src/memory_utils/memory_utils_ops_cpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
src/quantize_ops/quantize_ops_gpu.cpp
- src/sparse_ops/sparse_ops_gpu.cpp
- src/split_embeddings_utils/split_embeddings_utils.cpp
- src/split_embeddings_cache/split_embeddings_cache_ops.cu
- src/metric_ops/metric_ops_host.cpp
- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
- src/input_combine_ops/input_combine_gpu.cpp
- codegen/training/index_select/batch_index_select_dim0_host.cpp)
+ # src/sparse_ops/sparse_ops_gpu.cpp
+ # src/split_embeddings_utils/split_embeddings_utils.cpp
+ # src/split_embeddings_cache/split_embeddings_cache_ops.cu
+ # src/metric_ops/metric_ops_host.cpp
+ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
+ # src/input_combine_ops/input_combine_gpu.cpp
+ # codegen/training/index_select/batch_index_select_dim0_host.cpp)
+ )
if(NVML_LIB_PATH OR USE_ROCM)
message(STATUS "Adding merge_pooled_embeddings sources")
@@ -516,36 +518,36 @@ endif()
if(NOT FBGEMM_CPU_ONLY)
set(fbgemm_gpu_sources_static_gpu
- codegen/utils/embedding_bounds_check.cu
- codegen/inference/embedding_forward_quantized_split_lookup.cu
- src/embedding_inplace_ops/embedding_inplace_update.cu
- src/histogram_binning_calibration_ops.cu
- src/input_combine_ops/input_combine.cu
- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
- src/memory_utils/memory_utils.cu
- src/memory_utils/memory_utils_ops.cu
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
- src/jagged_tensor_ops/dense_to_jagged_forward.cu
- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
- src/jagged_tensor_ops/jagged_softmax_backward.cu
- src/jagged_tensor_ops/jagged_softmax_forward.cu
- src/jagged_tensor_ops/jagged_tensor_ops.cu
- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
- src/jagged_tensor_ops/jagged_unique_indices.cu
- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
- src/layout_transform_ops/layout_transform_ops.cu
- src/metric_ops/metric_ops.cu
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
+ # codegen/utils/embedding_bounds_check.cu
+ # codegen/inference/embedding_forward_quantized_split_lookup.cu
+ # src/embedding_inplace_ops/embedding_inplace_update.cu
+ # src/histogram_binning_calibration_ops.cu
+ # src/input_combine_ops/input_combine.cu
+ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu
+ # src/memory_utils/memory_utils.cu
+ # src/memory_utils/memory_utils_ops.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
+ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
+ # src/jagged_tensor_ops/dense_to_jagged_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu
+ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu
+ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
+ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
+ # src/jagged_tensor_ops/jagged_softmax_backward.cu
+ # src/jagged_tensor_ops/jagged_softmax_forward.cu
+ # src/jagged_tensor_ops/jagged_tensor_ops.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
+ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
+ # src/jagged_tensor_ops/jagged_unique_indices.cu
+ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+ # src/layout_transform_ops/layout_transform_ops.cu
+ # src/metric_ops/metric_ops.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
+ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
+ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
src/quantize_ops/quantize_bfloat16.cu
src/quantize_ops/quantize_fp8_rowwise.cu
src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY)
src/quantize_ops/quantize_msfp.cu
src/quantize_ops/quantize_padded_fp8_rowwise.cu
src/quantize_ops/quantize_mx.cu
- src/sparse_ops/sparse_async_cumsum.cu
- src/sparse_ops/sparse_block_bucketize_features.cu
- src/sparse_ops/sparse_bucketize_features.cu
- src/sparse_ops/sparse_batched_unary_embeddings.cu
- src/sparse_ops/sparse_compute_frequency_sequence.cu
- src/sparse_ops/sparse_expand_into_jagged_permute.cu
- src/sparse_ops/sparse_group_index.cu
- src/sparse_ops/sparse_index_add.cu
- src/sparse_ops/sparse_index_select.cu
- src/sparse_ops/sparse_invert_permute.cu
- src/sparse_ops/sparse_pack_segments_backward.cu
- src/sparse_ops/sparse_pack_segments_forward.cu
- src/sparse_ops/sparse_permute_1d.cu
- src/sparse_ops/sparse_permute_2d.cu
- src/sparse_ops/sparse_permute102.cu
- src/sparse_ops/sparse_permute_embeddings.cu
- src/sparse_ops/sparse_range.cu
- src/sparse_ops/sparse_reorder_batched_ad.cu
- src/sparse_ops/sparse_segment_sum_csr.cu
- src/sparse_ops/sparse_zipf.cu
- src/split_embeddings_cache/lfu_cache_find.cu
- src/split_embeddings_cache/lfu_cache_populate.cu
- src/split_embeddings_cache/lfu_cache_populate_byte.cu
- src/split_embeddings_cache/lru_cache_find.cu
- src/split_embeddings_cache/lru_cache_populate.cu
- src/split_embeddings_cache/lru_cache_populate_byte.cu
- src/split_embeddings_cache/lxu_cache.cu
- src/split_embeddings_cache/linearize_cache_indices.cu
- src/split_embeddings_cache/reset_weight_momentum.cu
- src/split_embeddings_utils/generate_vbe_metadata.cu
- src/split_embeddings_utils/get_infos_metadata.cu
- src/split_embeddings_utils/radix_sort_pairs.cu
- src/split_embeddings_utils/transpose_embedding_input.cu)
+ # src/sparse_ops/sparse_async_cumsum.cu
+ # src/sparse_ops/sparse_block_bucketize_features.cu
+ # src/sparse_ops/sparse_bucketize_features.cu
+ # src/sparse_ops/sparse_batched_unary_embeddings.cu
+ # src/sparse_ops/sparse_compute_frequency_sequence.cu
+ # src/sparse_ops/sparse_expand_into_jagged_permute.cu
+ # src/sparse_ops/sparse_group_index.cu
+ # src/sparse_ops/sparse_index_add.cu
+ # src/sparse_ops/sparse_index_select.cu
+ # src/sparse_ops/sparse_invert_permute.cu
+ # src/sparse_ops/sparse_pack_segments_backward.cu
+ # src/sparse_ops/sparse_pack_segments_forward.cu
+ # src/sparse_ops/sparse_permute_1d.cu
+ # src/sparse_ops/sparse_permute_2d.cu
+ # src/sparse_ops/sparse_permute102.cu
+ # src/sparse_ops/sparse_permute_embeddings.cu
+ # src/sparse_ops/sparse_range.cu
+ # src/sparse_ops/sparse_reorder_batched_ad.cu
+ # src/sparse_ops/sparse_segment_sum_csr.cu
+ # src/sparse_ops/sparse_zipf.cu
+ # src/split_embeddings_cache/lfu_cache_find.cu
+ # src/split_embeddings_cache/lfu_cache_populate.cu
+ # src/split_embeddings_cache/lfu_cache_populate_byte.cu
+ # src/split_embeddings_cache/lru_cache_find.cu
+ # src/split_embeddings_cache/lru_cache_populate.cu
+ # src/split_embeddings_cache/lru_cache_populate_byte.cu
+ # src/split_embeddings_cache/lxu_cache.cu
+ # src/split_embeddings_cache/linearize_cache_indices.cu
+ # src/split_embeddings_cache/reset_weight_momentum.cu
+ # src/split_embeddings_utils/generate_vbe_metadata.cu
+ # src/split_embeddings_utils/get_infos_metadata.cu
+ # src/split_embeddings_utils/radix_sort_pairs.cu
+ # src/split_embeddings_utils/transpose_embedding_input.cu)
+ )
set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
PROPERTIES COMPILE_OPTIONS
diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 01f1d6ab..a6b8d7a8 100644
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories
${THIRDPARTY}/json/include
${NCCL_INCLUDE_DIRS})
-set(attention_ops_sources
- src/attention/attention.cpp
- src/attention/gqa_attn_splitk.cu)
+# set(attention_ops_sources
+# src/attention/attention.cpp
+# src/attention/gqa_attn_splitk.cu)
set(quantize_ops_sources
src/quantize/cutlass_extensions.cu
src/quantize/quantize.cu
src/quantize/quantize.cpp)
-set(comm_ops_sources
- src/comm/car.cu
- src/comm/car.cpp)
+# set(comm_ops_sources
+# src/comm/car.cu
+# src/comm/car.cpp)
set(experimental_gen_ai_cpp_source_files
- ${attention_ops_sources}
+ # ${attention_ops_sources}
${quantize_ops_sources}
- ${comm_ops_sources})
+ # ${comm_ops_sources}
+)
set_source_files_properties(${experimental_gen_ai_cpp_source_files}
PROPERTIES INCLUDE_DIRECTORIES
server/fix_torch90a.sh (new executable file)
@@ -0,0 +1,11 @@
#!/bin/bash
# This script is required to patch torch < 2.4
# It adds the 90a cuda target (H100)
# This target is required to build FBGEMM kernels
torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|')
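# The three seds below edit lines 189, 245 and 246 of select_compute_arch.cmake,
# relaxing its version regexes so that an optional "a" suffix (e.g. "9.0a") is
# accepted when compute capabilities are parsed.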
sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch
sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch
sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch
@@ -59,3 +59,18 @@
Matrix multiplication using Marlin kernels.
"""
...
# fp8 marlin
def fp8_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
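# a: activations of shape (size_m, size_k); b_q_weight: Marlin-repacked FP8
# weights; b_scales: weight scales; workspace: scratch tensor required by the
# kernel; num_bits is 8 for FP8 weight-only quantization (see the CUDA kernel
# for the exact layouts).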
return torch.ops._C.fp8_marlin_gemm(
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
)
@@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
+// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
+m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}
@@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
+torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                              torch::Tensor& b_scales, torch::Tensor& workspace,
+                              int64_t num_bits, int64_t size_m, int64_t size_n,
+                              int64_t size_k);
#endif
(File diff suppressed because it is too large)
@@ -9,6 +9,7 @@ setup(
CUDAExtension(
name="marlin_kernels",
sources=[
+"marlin_kernels/fp8_marlin.cu",
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",
server/marlin/setup.py
@@ -2,6 +2,7 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
+from text_generation_server.utils.weights import DefaultWeightsLoader

class ProcessGroup:
@@ -42,7 +43,12 @@ class Weights:
def test_weight_hub_files_offline_error():
vocab_size = 17
-weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
+weights = Weights(
+    rank=0,
+    world_size=1,
+    vocab_size=vocab_size,
+    hidden_dim=256,
+)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
server/tests/utils/test_weights.py
@@ -1,13 +1,48 @@
import pytest
import torch
-from text_generation_server.utils.weights import Weights
-from text_generation_server.layers.gptq import GPTQWeight
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.marlin import MarlinWeight
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+    WeightsLoader,
+)
+from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
+from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
+from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path
@pytest.fixture
def gptq_weights_loader():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="gptq",
quantize="gptq",
sym=True,
)
@pytest.fixture
def gptq_weights_loader_awq():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="awq",
quantize="awq",
sym=True,
)
@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
dummy_file_system = {
"test_weights": {
"layer.0.weight": torch.tensor(
@@ -58,7 +93,7 @@ dummy_file_system = {
dtype=torch.float32,
),
},
-"test_get_multi_weights_row": {
+"test_get_weights_row": {
"weight.weight": torch.tensor(
[
[1, 2],
@@ -101,7 +136,7 @@ dummy_file_system = {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
-"test_get_multi_weights_row_gptq": {
+"test_get_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
@@ -200,7 +235,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
-"test_get_multi_weights_row_exl2": {
+"test_get_weights_row_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
@@ -245,7 +280,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
-"test_get_multi_weights_row_marlin": {
+"test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
@@ -308,6 +343,7 @@ class MockWeights(Weights):
dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
+weights_loader: Optional[WeightsLoader] = None,
):
routing = {}
self.dummy_fs = dummy_fs
@@ -327,6 +363,12 @@ class MockWeights(Weights):
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
+self.weights_loader = (
+    # We don't need to get linear layers, so just wrap raw tensors.
+    DefaultWeightsLoader(lambda x: x)
+    if weights_loader is None
+    else weights_loader
+)
self._handles = {}

def _get_handle(self, filename: Union[Path, str]):
@@ -412,12 +454,10 @@ def test_get_weights_col_packed():
)

prefix = "weight"
-quantize = None
block_sizes = 1

w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -448,12 +488,10 @@ def test_get_weights_col_packed_block_size():
)

prefix = "weight"
-quantize = None
block_sizes = 2

w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -484,12 +522,10 @@ def test_get_weights_col_packed_block_size_arr():
)

prefix = "weight"
-quantize = None
block_sizes = [1, 1]

w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -519,11 +555,9 @@ def test_get_multi_weights_col():
)

prefixes = ["weight", "weight"]
-quantize = None

w = weights.get_multi_weights_col(
prefixes=prefixes,
-quantize=quantize,
dim=0,
)
@@ -545,10 +579,10 @@
)

-def test_get_multi_weights_row():
+def test_get_weights_row():
weights = MockWeights(
[
-"test_get_multi_weights_row",
+"test_get_weights_row",
],
device="cpu",
dtype=torch.float32,
@@ -557,11 +591,9 @@ def test_get_weights_row():
)

prefix = "weight"
-quantize = None

-w = weights.get_multi_weights_row(
+w = weights.get_weights_row(
prefix=prefix,
-quantize=quantize,
)

assert torch.allclose(
@@ -576,7 +608,7 @@ def test_get_weights_row():
# test_get_weights_col

-def test_get_weights_col_awq():
+def test_get_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_weights_col_gptq",
@@ -585,14 +617,13 @@ def test_get_weights_col_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
+weights_loader=gptq_weights_loader_awq,
)

prefix = "weight"
-quantize = "awq"

w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)

expected_weight = GPTQWeight(
@@ -605,6 +636,7 @@ def test_get_weights_col_awq():
g_idx=None,
bits=8.0,
groupsize=2.0,
+use_awq_kernel=True,
use_exllama=False,
)
@@ -614,10 +646,11 @@ def test_get_weights_col_awq():
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
-def test_get_weights_col_gtpq():
+def test_get_weights_col_gtpq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_gptq",
@@ -626,14 +659,13 @@ def test_get_weights_col_gtpq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
+weights_loader=gptq_weights_loader,
)

prefix = "weight"
-quantize = "gptq"

w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)

expected_weight = GPTQWeight(
@@ -643,6 +675,7 @@ def test_get_weights_col_gtpq():
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
+use_awq_kernel=False,
use_exllama=False,
)
@@ -652,6 +685,7 @@ def test_get_weights_col_gtpq():
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@@ -664,14 +698,13 @@ def test_get_weights_col_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
+weights_loader=Exl2WeightsLoader(),
)

prefix = "weight"
-quantize = "exl2"

w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)

scaled_scale_max = 0.3906 * 256
@@ -692,7 +725,7 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
-def test_get_weights_col_marlin():
+def test_get_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_marlin",
@@ -701,14 +734,13 @@ def test_get_weights_col_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
+weights_loader=marlin_weights_loader,
)

prefix = "weight"
-quantize = "marlin"

w = weights.get_weights_col(
prefix=prefix,
-quantize=quantize,
)

expected_weight = MarlinWeight(
@@ -723,7 +755,7 @@ def test_get_weights_col_marlin():
# test_get_weights_col_packed

-def test_get_weights_col_packed_awq():
+def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
@@ -732,15 +764,14 @@ def test_get_weights_col_packed_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
+weights_loader=gptq_weights_loader_awq,
)

prefix = "weight"
-quantize = "awq"
block_sizes = 1

w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -751,6 +782,7 @@ def test_get_weights_col_packed_awq():
g_idx=None,
bits=8.0,
groupsize=2.0,
+use_awq_kernel=True,
use_exllama=False,
)
@@ -760,6 +792,7 @@ def test_get_weights_col_packed_awq():
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
+assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@@ -773,15 +806,14 @@ def test_get_weights_col_packed_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
+weights_loader=Exl2WeightsLoader(),
)

prefix = "weight"
-quantize = "exl2"
block_sizes = 1

w = weights.get_weights_col_packed(
prefix=prefix,
-quantize=quantize,
block_sizes=block_sizes,
)
@@ -803,7 +835,7 @@ def test_get_weights_col_packed_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_packed_gptq(): def test_get_weights_col_packed_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_gptq", "test_get_weights_col_packed_gptq",
@ -812,14 +844,13 @@ def test_get_weights_col_packed_gptq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -830,6 +861,7 @@ def test_get_weights_col_packed_gptq():
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
@ -839,10 +871,11 @@ def test_get_weights_col_packed_gptq():
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_packed_marlin(): def test_get_weights_col_packed_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_marlin", "test_get_weights_col_packed_marlin",
@ -851,14 +884,13 @@ def test_get_weights_col_packed_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
@ -876,7 +908,7 @@ def test_get_weights_col_packed_marlin():
# test_get_multi_weights_col # test_get_multi_weights_col
def test_get_multi_weights_col_awq(): def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_gptq", "test_get_multi_weights_col_gptq",
@ -885,14 +917,13 @@ def test_get_multi_weights_col_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "awq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -903,6 +934,7 @@ def test_get_multi_weights_col_awq():
g_idx=None, g_idx=None,
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=True,
use_exllama=False, use_exllama=False,
) )
@ -912,6 +944,7 @@ def test_get_multi_weights_col_awq():
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@ -924,22 +957,21 @@ def test_get_multi_weights_col_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
try: try:
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
except ValueError as e: except ValueError as e:
assert e.args[0] == "get_multi_weights_col is not supported for exl2" assert e.args[0] == "get_multi_weights_col is not supported for exl2"
def test_get_multi_weights_col_gptq(): def test_get_multi_weights_col_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_gptq", "test_get_multi_weights_col_gptq",
@ -948,14 +980,13 @@ def test_get_multi_weights_col_gptq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -966,6 +997,7 @@ def test_get_multi_weights_col_gptq():
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
@ -975,10 +1007,11 @@ def test_get_multi_weights_col_gptq():
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_marlin(): def test_get_multi_weights_col_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_marlin", "test_get_multi_weights_col_marlin",
@ -987,14 +1020,13 @@ def test_get_multi_weights_col_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
@ -1007,26 +1039,25 @@ def test_get_multi_weights_col_marlin():
assert torch.allclose(w.s, expected_weight.s), "s mismatch" assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_row # test_get_weights_row
def test_get_multi_weights_row_awq(): def test_get_weights_row_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_gptq", "test_get_weights_row_gptq",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
@ -1036,6 +1067,7 @@ def test_get_multi_weights_row_awq():
g_idx=None, g_idx=None,
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=True,
use_exllama=False, use_exllama=False,
) )
@ -1045,26 +1077,26 @@ def test_get_multi_weights_row_awq():
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_exl2(): def test_get_weights_row_exl2():
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_exl2", "test_get_weights_row_exl2",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
print(w) print(w)
@ -1086,23 +1118,22 @@ def test_get_multi_weights_row_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_multi_weights_row_gptq(): def test_get_weights_row_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_gptq", "test_get_weights_row_gptq",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "gptq"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
@ -1112,6 +1143,7 @@ def test_get_multi_weights_row_gptq():
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
@ -1121,26 +1153,26 @@ def test_get_multi_weights_row_gptq():
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_marlin(): def test_get_weights_row_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_marlin", "test_get_weights_row_marlin",
], ],
device="cpu", device="cpu",
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(
View File
@ -8,6 +8,7 @@ from typing import Optional
from enum import Enum from enum import Enum
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.utils.log import log_master
app = typer.Typer() app = typer.Typer()
@ -87,10 +88,21 @@ def serve(
) )
if len(lora_adapter_ids) > 0: if len(lora_adapter_ids) > 0:
-        logger.warning(
-            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
-        )
+        log_master(
+            logger.warning,
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.",
+        )
# TODO: enable LoRA with CUDA graphs. For now, disable CUDA graphs if LoRA
# is enabled and warn the user
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
log_master(
logger.warning,
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.",
)
global CUDA_GRAPHS
CUDA_GRAPHS = None
# Downgrade enum into str for easier management later on # Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value dtype = None if dtype is None else dtype.value
@ -332,6 +344,7 @@ def quantize(
upload_to_model_id: Optional[str] = None, upload_to_model_id: Optional[str] = None,
percdamp: float = 0.01, percdamp: float = 0.01,
act_order: bool = False, act_order: bool = False,
groupsize: int = 128,
): ):
if revision is None: if revision is None:
revision = "main" revision = "main"
@ -346,13 +359,14 @@ def quantize(
quantize( quantize(
model_id=model_id, model_id=model_id,
bits=4, bits=4,
groupsize=128, groupsize=groupsize,
output_dir=output_dir, output_dir=output_dir,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id, upload_to_model_id=upload_to_model_id,
percdamp=percdamp, percdamp=percdamp,
act_order=act_order, act_order=act_order,
sym=True,
) )
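For context, a hedged sketch of how the new parameter surfaces on the CLI. The model id and output path are placeholders, and `--groupsize` is assumed to be the typer-generated flag name for the argument added above:

    text-generation-server quantize bigscience/bloom-560m /data/bloom-gptq --groupsize 128

Since `sym=True` is now hard-coded here, checkpoints produced this way record symmetric quantization, which the `gptq_sym` tensor written later in this diff persists.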
View File
@ -3,6 +3,7 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger from loguru import logger
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
@ -136,7 +137,10 @@ if ENGINE != "triton":
try: try:
import flash_attn_2_cuda import flash_attn_2_cuda
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") log_master(
logger.info,
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
)
except ImportError as e: except ImportError as e:
if major >= 8: if major >= 8:
architecture_suffix = f"-{SYSTEM}" architecture_suffix = f"-{SYSTEM}"
View File
@ -1,15 +1,18 @@
-import torch
-from loguru import logger
+from dataclasses import dataclass
 from functools import lru_cache

 import bitsandbytes as bnb
+import torch
 from bitsandbytes.nn import Int8Params, Params4bit

+from text_generation_server.utils.weights import UnquantizedWeight

-@lru_cache(1)
-def warn_deprecate_bnb():
-    logger.warning(
-        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
-    )
+
+@dataclass
+class BNBWeight(UnquantizedWeight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
class Linear8bitLt(torch.nn.Module): class Linear8bitLt(torch.nn.Module):
@ -70,6 +73,22 @@ class Linear8bitLt(torch.nn.Module):
return out return out
@dataclass
class BNBFP4Weight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="fp4")
@dataclass
class BNBNF4Weight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="nf4")
class Linear4bit(torch.nn.Module): class Linear4bit(torch.nn.Module):
def __init__(self, weight, bias, quant_type): def __init__(self, weight, bias, quant_type):
super().__init__() super().__init__()
View File
@ -1,5 +1,23 @@
from dataclasses import dataclass
import torch import torch
from EETQ import quant_weights, w8_a16_gemm from EETQ import quant_weights, w8_a16_gemm
from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
class EETQWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
try:
from text_generation_server.layers.eetq import EETQLinear
return EETQLinear(self.weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
class EETQLinear(torch.nn.Module): class EETQLinear(torch.nn.Module):
View File
@ -1,9 +1,12 @@
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union
import torch
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass @dataclass
class Exl2Weight: class Exl2Weight(Weight):
""" """
Exllama2 exl2 quantized weights. Exllama2 exl2 quantized weights.
""" """
@ -21,3 +24,55 @@ class Exl2Weight:
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
return self.q_weight.device return self.q_weight.device
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.gptq import ExllamaQuantLinear
return ExllamaQuantLinear(self, bias)
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor parallelism.
"""
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
# Sharding is not yet supported, so we return the weights as-is.
return self.get_weights(weights, prefix)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
# Sharding is not yet supported, so we return the weights as-is.
return self.get_weights(weights, prefix)
View File
@ -1,12 +1,58 @@
import torch import torch
from dataclasses import dataclass
from typing import Optional, Union, List
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from text_generation_server.utils.log import log_master, log_once
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
try:
import fbgemm_gpu.experimental.gen_ai
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
except (ImportError, ModuleNotFoundError):
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
def get_fp8_linear() -> torch.nn.Module:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, minor = torch.cuda.get_device_capability()
if major == 8 and minor < 9:
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
if FBGEMM_DYN_AVAILABLE:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
device = weight.device
# weight, scale = quant_weights(weight, torch.int8, False) # weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype) finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax # Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12) scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# scale and clamp the tensor to bring it to # scale and clamp the tensor to bring it to
# the representative range of float8 data type # the representative range of float8 data type
# (as default cast is unsaturated) # (as default cast is unsaturated)
@ -18,19 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
return qweight, scale return qweight, scale
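A quick numeric check of the scaling above, assuming the non-FBGEMM path: torch.finfo(torch.float8_e4m3fn).max is 448.0 in PyTorch, so a tensor whose absolute maximum is 0.5 gets scale = 448.0 / 0.5 = 896.0 before the saturating cast:

    import torch

    w = torch.tensor([[0.5, -0.25]])
    finfo = torch.finfo(torch.float8_e4m3fn)            # finfo.max == 448.0
    scale = finfo.max / w.abs().max().clamp(min=1e-12)  # tensor(896.)
    qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # the elided remainder of the function returns a scale alongside qweight
    # so that the matmul can later undo the rescaling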
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
w = torch.cat(w, dim=dim)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = [
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
for p in prefixes
]
scale = torch.cat(scale, dim=0)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
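A hedged sketch of what each branch of the hybrid loader returns; the prefix is a placeholder:

    loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=True)
    w = loader.get_weights_row(weights, "model.layers.0.mlp.down_proj")
    # float8_e4m3fn checkpoint tensor -> Fp8Weight with a loaded weight_scale
    # fp16/bf16 tensor, to_fp8=True   -> Fp8Weight, quantized later in get_linear
    # fp16/bf16 tensor, to_fp8=False  -> UnquantizedWeight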
@dataclass
class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
)
 class Fp8Linear(torch.nn.Module):
     def __init__(
         self,
-        weight,
+        qweight,
+        scale,
+        scale_upper_bound,
         bias,
+        dtype,
     ) -> None:
         super().__init__()
-        self.dtype = weight.dtype
-        self.qweight, self.scale = fp8_quantize(weight)
+        self.dtype = dtype
+        self.qweight = qweight
+        self.scale = scale
+        self.scale_upper_bound = (
+            torch.tensor(
+                [scale_upper_bound], dtype=torch.float32, device=qweight.device
+            )
+            if scale_upper_bound is not None
+            else None
+        )

         self.bias = bias if bias is not None else None
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
)
@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
bias=bias,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.qweight,
scale,
self.scale,
use_fast_accum=True,
bias=self.bias,
)
return y.to(self.dtype)
qinput, scale = fp8_quantize(input) qinput, scale = fp8_quantize(input)
output, _ = torch._scaled_mm( output, _ = torch._scaled_mm(
qinput, qinput,
View File
@ -1,30 +1,23 @@
-from dataclasses import dataclass
 import os
-from typing import Optional
+from dataclasses import dataclass
+from typing import List, Optional, Union

 import torch
-from text_generation_server.utils.import_utils import (
-    SYSTEM,
-)
+from loguru import logger
+
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader


 @dataclass
-class GPTQParams:
-    bits: int
-    checkpoint_format: Optional[str]
-    groupsize: int
-    desc_act: bool
-    quant_method: str
-    sym: bool
-
-
-@dataclass
-class GPTQWeight:
+class GPTQWeight(Weight):
     qweight: torch.Tensor
     qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: Optional[torch.Tensor]
     bits: int
     groupsize: int
+    use_awq_kernel: bool
     use_exllama: bool
def __post_init__(self): def __post_init__(self):
@ -35,6 +28,50 @@ class GPTQWeight:
def device(self) -> torch.device: def device(self) -> torch.device:
return self.qweight.device return self.qweight.device
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
return WQLinear(
w_bit=self.bits,
group_size=self.groupsize,
qweight=self.qweight,
qzeros=self.qzeros,
scales=self.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
bias,
self.bits,
self.groupsize,
)
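A small sketch of the dispatch above; tensor contents are elided, and the flags are the ones GPTQWeightsLoader sets later in this file:

    w = GPTQWeight(
        qweight=qweight, qzeros=qzeros, scales=scales, g_idx=g_idx,
        bits=4, groupsize=128,
        use_awq_kernel=False,  # True selects the AWQ WQLinear GEMM (CUDA only)
        use_exllama=True,      # True selects ExllamaQuantLinear
    )
    linear = w.get_linear(bias=None)  # both flags False: generic QuantLinear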
try: try:
major, _minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
@ -51,6 +88,8 @@ elif CAN_EXLLAMA:
if V2: if V2:
from text_generation_server.layers.gptq.exllamav2 import ( from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear, QuantLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllamav2 import (
create_exllama_buffers, create_exllama_buffers,
set_device, set_device,
) )
@ -59,6 +98,8 @@ elif CAN_EXLLAMA:
else: else:
from text_generation_server.layers.gptq.exllama import ( from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear, Ex4bitLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllama import (
create_exllama_buffers, create_exllama_buffers,
set_device, set_device,
) )
@ -69,3 +110,457 @@ elif CAN_EXLLAMA:
pass pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
else:
g_idx = None
if weights.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split onto several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
View File
@ -9,11 +9,12 @@ from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_master
try: try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError: except ImportError:
logger.error("exllamav2_kernels not installed.") log_master(logger.warning, "exllamav2_kernels not installed.")
raise raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
View File
@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from text_generation_server.utils.weights import DefaultWeightsLoader
DEV = torch.device("cuda:0") DEV = torch.device("cuda:0")
@ -869,6 +871,7 @@ def quantize(
upload_to_model_id: Optional[str], upload_to_model_id: Optional[str],
percdamp: float, percdamp: float,
act_order: bool, act_order: bool,
sym: bool,
): ):
print("loading model") print("loading model")
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
@ -891,6 +894,7 @@ def quantize(
dtype=torch.float16, dtype=torch.float16,
process_group=process_group, process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]}, aliases={"embed_tokens.weight": ["lm_head.weight"]},
weights_loader=DefaultWeightsLoader(),
) )
hooks = [] hooks = []
for name, module in model.named_modules(): for name, module in model.named_modules():
@ -943,6 +947,7 @@ def quantize(
percdamp=percdamp, percdamp=percdamp,
act_order=act_order, act_order=act_order,
hooks=hooks, hooks=hooks,
sym=sym,
) )
print(time.time() - tick) print(time.time() - tick)
@ -954,6 +959,7 @@ def quantize(
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_bits"] = torch.LongTensor([bits])
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
state_dict["gptq_sym"] = torch.BoolTensor([sym])
max_shard_size = "10GB" max_shard_size = "10GB"
shards, index = shard_checkpoint( shards, index = shard_checkpoint(
View File
@ -1,7 +1,8 @@
from typing import Optional from typing import Optional
import torch import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module):
return F.linear(inp, self.weight, self.bias) return F.linear(inp, self.weight, self.bias)
-def get_linear(weight, bias, quantize):
-    if quantize is None:
+def get_linear(weight, bias):
+    # Weights that are loaded through methods that are not
+    # quantization-aware are still bare tensors. We may want
+    # to change this in the future.
+    if isinstance(weight, torch.Tensor):
         if SYSTEM == "rocm":
-            linear = FastLinearROCm(weight, bias)
+            return FastLinearROCm(weight, bias)
         else:
-            linear = FastLinear(weight, bias)
-    elif quantize == "eetq":
-        try:
-            from text_generation_server.layers.eetq import EETQLinear
-
-            linear = EETQLinear(weight, bias)
+            return FastLinear(weight, bias)
+
+    return weight.get_linear(bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "fp8":
from text_generation_server.layers.fp8 import Fp8Linear
linear = Fp8Linear(weight, bias)
elif quantize == "bitsandbytes":
try:
from text_generation_server.layers.bnb import (
warn_deprecate_bnb,
Linear8bitLt,
)
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
warn_deprecate_bnb()
linear = Linear8bitLt(
weight,
bias,
has_fp16_weights=False,
threshold=6.0,
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
try:
from text_generation_server.layers.bnb import Linear4bit
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
try:
from text_generation_server.layers.bnb import Linear4bit
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
if not isinstance(weight, Exl2Weight):
raise NotImplementedError(
f"The passed weight is not `exl2` compatible, loader needs to be updated."
)
from text_generation_server.layers.gptq import ExllamaQuantLinear
linear = ExllamaQuantLinear(weight, bias)
elif quantize == "gptq":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlinLinear,
GPTQMarlinWeight,
)
if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQWeight):
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
)
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
weight.bits,
weight.groupsize,
)
else:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
linear = WQLinear(
w_bit=weight.bits,
group_size=weight.groupsize,
qweight=weight.qweight,
qzeros=weight.qzeros,
scales=weight.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Linear,
GPTQMarlin24Weight,
MarlinLinear,
MarlinWeight,
)
if isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear(
weight=weight,
bias=bias,
)
elif isinstance(weight, MarlinWeight):
linear = MarlinLinear(weight=weight, bias=bias)
else:
raise NotImplementedError(
f"The passed weight is not `marlin` compatible, loader needs to be updated."
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
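The net effect of this rewrite is that get_linear no longer branches on a quantize string at all; a minimal sketch of the new calling convention, assuming any of the Weight subclasses introduced in this diff:

    linear = get_linear(weight, bias)
    # bare torch.Tensor              -> FastLinear / FastLinearROCm
    # MarlinWeight(B=B, s=s)         -> MarlinLinear
    # Fp8Weight(weight=w, dtype=dt)  -> Fp8Linear (quantizing on the fly)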
View File
@ -1,11 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from loguru import logger
from text_generation_server.layers.gptq import GPTQParams from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: try:
import marlin_kernels import marlin_kernels
@ -24,16 +26,159 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16 MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor parallelism.
"""
if self.is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. Share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. Share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
     return (
         SYSTEM == "cuda"
         and marlin_kernels is not None
         and has_sm_8_0
         and quantize == "gptq"
-        and gptq_params.quant_method == "gptq"
-        and gptq_params.bits in GPTQ_MARLIN_BITS
-        and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
-        and gptq_params.sym
+        and quant_method == "gptq"
+        and bits in GPTQ_MARLIN_BITS
+        and groupsize in GPTQ_MARLIN_GROUP_SIZES
+        and sym
     )
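Concretely, a symmetric 4-bit GPTQ checkpoint with groupsize 128 served on an SM 8.0+ GPU with marlin_kernels installed takes the Marlin repacking path; AWQ-packed checkpoints (quant_method == "awq"), asymmetric ones, and unsupported bit widths fall through to the exllama/QuantLinear paths.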
@ -83,7 +228,7 @@ def permute_scales(scales: torch.Tensor):
@dataclass @dataclass
class GPTQMarlinWeight: class GPTQMarlinWeight(Weight):
""" """
Repacked GPTQ Marlin weights. Repacked GPTQ Marlin weights.
""" """
@ -101,6 +246,12 @@ class GPTQMarlinWeight:
assert self.g_idx.dtype == torch.int32 assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32 assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin( def repack_gptq_for_marlin(
*, *,
@ -258,6 +409,12 @@ class GPTQMarlin24Weight:
assert self.B_meta.dtype == torch.int16 assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16 assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module): class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
@ -339,8 +496,126 @@ class GPTQMarlin24Linear(nn.Module):
return C return C
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
qweight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scale = scale.to(torch.float16)
qweight, scales = repack_fp8_for_marlin(qweight, scale)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias if bias is not None else None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias, _dtype):
qweight, scale = fp8_quantize(weight)
return cls(qweight=qweight, scale=scale, bias=bias)
@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
return cls(qweight=weight, scale=scale, bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
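A hedged worked example of the packing above; the values are arbitrary:

    t = torch.arange(1.0, 5.0).reshape(4, 1).to(torch.float8_e4m3fn)
    packed = pack_fp8_as_int32(t)            # shape (1, 1), dtype torch.int32
    b0 = t.view(torch.uint8)[0, 0].item()
    assert packed[0, 0].item() & 0xFF == b0  # byte i lands at bits [8*i, 8*i+8)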
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layer weights have shape [out_features, in_features];
# GPTQ-quantized weights use [in_features / pack_factor, out_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = scale.reshape(1, 1).repeat(1, out_features)
scales = permute_scales(scales)
return repacked, scales
@dataclass @dataclass
class MarlinWeight: class MarlinWeight(Weight):
""" """
Marlin weights. Marlin weights.
@ -356,6 +631,9 @@ class MarlinWeight:
assert self.B.dtype == torch.int32 assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16 assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module): class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):

View File

@ -1,6 +1,7 @@
import os import os
import torch import torch
from torch import nn from torch import nn
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -97,18 +98,22 @@ class PositionRotaryEmbedding(nn.Module):
) )
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling[ max_position_embeddings=rope_scaling[
"original_max_position_embeddings" "original_max_position_embeddings"
], ],
base=10000.0, base=base,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1, beta_slow=1,
mscale=mscale,
mscale_all_dim=mscale_all_dim,
) )
elif rope_scaling["type"] in ["su", "longrope"]: elif rope_scaling["type"] in ["su", "longrope"]:
short_factor = torch.tensor( short_factor = torch.tensor(
@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
) )
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling[ max_position_embeddings=rope_scaling[
@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module):
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1, beta_slow=1,
mscale=mscale,
mscale_all_dim=mscale_all_dim,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim):
return ramp_func return ramp_func
def get_mscale(scale=1): def get_mscale(scale: float = 1.0, mscale: float = 1.0):
if scale <= 1: if scale <= 1:
return 1.0 return 1.0
return 0.1 * math.log(scale) + 1.0 return 0.1 * mscale * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
attn_factor, attn_factor,
beta_fast, beta_fast,
beta_slow, beta_slow,
mscale: float,
mscale_all_dim: float,
): ):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor) super().__init__(inv_freq, scaling_factor)
@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.mscale_all_dim = mscale_all_dim
self.scaling_factor = scaling_factor
self.mscale = float( self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor get_mscale(self.scaling_factor, mscale)
/ get_mscale(self.scaling_factor, mscale_all_dim)
* self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation ) # Get n-d magnitude scaling corrected for interpolation
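With the two-argument get_mscale, the constructor now computes a ratio of magnitude scales instead of a single value. A worked sketch (editor's addition; the numbers are DeepSeek-V2-style assumptions, not values read from this diff):

import math

def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

scaling_factor, mscale, mscale_all_dim, attn_factor = 40.0, 0.707, 0.707, 1.0
effective = (
    get_mscale(scaling_factor, mscale)
    / get_mscale(scaling_factor, mscale_all_dim)
    * attn_factor
)
print(effective)  # 1.0: when mscale == mscale_all_dim the ratio cancels here,
# and the magnitude correction is instead folded into the attention softmax scale.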
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
if seqlen > self.max_position_embeddings: if seqlen > self.max_position_embeddings or True:
inv_freq_extrapolation = _create_inv_freq( inv_freq_extrapolation = _create_inv_freq(
self.dim, self.base, self.inv_freq.device self.dim, self.base, self.inv_freq.device
) )
@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
self.base, self.base,
self.max_position_embeddings, self.max_position_embeddings,
) )
inv_freq_mask = ( inv_freq_mask = (
1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
) )
self.inv_freq = inv_freq self.inv_freq = inv_freq
self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)

View File

@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight") weight = weights.get_tensor(f"{prefix}.weight")
except: except:
# ...otherwise they are quantized. # ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1 should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1: elif weights.process_group.size() > 1:
try: try:
@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer):
quantize = config.quantize quantize = config.quantize
return TensorParallelHead( return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize), get_linear(weight, bias=None),
process_group=weights.process_group, process_group=weights.process_group,
should_gather=should_gather, should_gather=should_gather,
) )
@ -129,14 +129,12 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod @classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool): def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up( weight = weights.get_weights_col_packed_gate_up(prefix)
prefix, quantize=config.quantize
)
if bias: if bias:
raise NotImplementedError("packed_gate_up only implemented without bias") raise NotImplementedError("packed_gate_up only implemented without bias")
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@classmethod @classmethod
@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv( weight = weights.get_weights_col_packed_qkv(
prefix, prefix,
quantize=config.quantize,
num_heads=num_heads, num_heads=num_heads,
num_key_value_heads=num_key_value_heads, num_key_value_heads=num_key_value_heads,
) )
@ -160,17 +157,17 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_qkv only implemented for baichuan") raise NotImplementedError("packed_qkv only implemented for baichuan")
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
if bias: if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@classmethod @classmethod
@ -178,20 +175,18 @@ class TensorParallelColumnLinear(SuperLayer):
if config.quantize == "exl2": if config.quantize == "exl2":
linears = [] linears = []
for prefix in prefixes: for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize)) linears.append(get_linear(weight, b))
linear = LayerConcat(linears) linear = LayerConcat(linears)
else: else:
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(prefixes, dim=dim)
prefixes, quantize=config.quantize, dim=dim
)
if bias: if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim) bias = torch.cat(b, dim=dim)
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
@ -210,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer):
else: else:
bias = None bias = None
return cls( return cls(
get_linear(weight, bias, config.quantize), get_linear(weight, bias),
process_group=weights.process_group, process_group=weights.process_group,
) )

View File

@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import (
) )
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_master
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later. # in PyTorch 1.12 and later.
@ -47,9 +48,7 @@ torch.set_grad_enabled(False)
__all__ = [ __all__ = [
"Model", "Model",
"BLOOMSharded",
"CausalLM", "CausalLM",
"GalacticaSharded",
"Seq2SeqLM", "Seq2SeqLM",
"get_model", "get_model",
] ]
@ -61,6 +60,10 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
) )
@ -121,7 +124,7 @@ try:
) )
from text_generation_server.layers.attention import SUPPORTS_WINDOWING from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e: except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}") log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False FLASH_ATTENTION = False
@ -133,7 +136,7 @@ MAMBA_AVAILABLE = True
try: try:
from text_generation_server.models.mamba import Mamba from text_generation_server.models.mamba import Mamba
except ImportError as e: except ImportError as e:
logger.warning(f"Could not import Mamba: {e}") log_master(logger.warning, f"Could not import Mamba: {e}")
MAMBA_AVAILABLE = False MAMBA_AVAILABLE = False
if MAMBA_AVAILABLE: if MAMBA_AVAILABLE:
@ -141,6 +144,11 @@ if MAMBA_AVAILABLE:
class ModelType(enum.Enum): class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
IDEFICS2 = { IDEFICS2 = {
"type": "idefics2", "type": "idefics2",
"name": "Idefics 2", "name": "Idefics 2",
@ -302,6 +310,12 @@ def get_model(
if quantize in ["awq", "exl2", "gptq", "marlin"]: if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params. # These quantizers only work with float16 params.
dtype = torch.float16 dtype = torch.float16
elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
if FBGEMM_MM_AVAILABLE:
# fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16
else: else:
# Keep it as default for now and let # Keep it as default for now and let
# every model resolve their own default dtype. # every model resolve their own default dtype.
@ -424,7 +438,9 @@ def get_model(
speculate = get_speculate() speculate = get_speculate()
if speculate > 0: if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.") log_master(
logger.info, f"Using speculation {method} with {speculate} input ids."
)
if model_type is None: if model_type is None:
# TODO: fix how we determine model type for Mamba # TODO: fix how we determine model type for Mamba
@ -439,10 +455,10 @@ def get_model(
if quantization_config is not None and quantize is None: if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None) method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}: if method in {"gptq", "awq", "exl2"}:
logger.info(f"Auto selecting quantization method {method}") log_master(logger.info, f"Auto selecting quantization method {method}")
quantize = method quantize = method
else: else:
logger.info(f"Unknown quantization method {method}") log_master(logger.warning, f"Unknown quantization method {method}")
if quantize == "exl2" and sharded: if quantize == "exl2" and sharded:
raise RuntimeError( raise RuntimeError(
@ -459,7 +475,40 @@ def get_model(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
) )
if model_type == MAMBA: if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
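# Editor's note (illustrative values assumed from typical DeepSeek-V2
# configs, not read from this diff): with qk_nope_head_dim=128,
# qk_rope_head_dim=64 and v_head_dim=128 this evaluates to
# head_size = max(128 + 64, 128) = 192, the padded head size that
# DeepseekV2Attention later pads queries, keys and values to.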
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == MAMBA:
return Mamba( return Mamba(
model_id, model_id,
revision, revision,
@ -551,7 +600,7 @@ def get_model(
) )
except RuntimeError as e: except RuntimeError as e:
# Lots of legacy models with various weight names. # Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}") log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
return CausalLM.fallback( return CausalLM.fallback(
model_id, model_id,
revision, revision,
@ -573,6 +622,10 @@ def get_model(
) )
elif model_type == GPT_NEOX: elif model_type == GPT_NEOX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
return FlashCausalLM( return FlashCausalLM(
model_id=model_id, model_id=model_id,
model_class=FlashGPTNeoXForCausalLM, model_class=FlashGPTNeoXForCausalLM,
@ -582,6 +635,7 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
) )
elif sharded: elif sharded:
return CausalLM( return CausalLM(
@ -797,6 +851,10 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig, config_class=RWConfig,

View File

@ -20,6 +20,7 @@ from text_generation_server.utils import (
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
@ -491,7 +492,7 @@ class CausalLMBatch(Batch):
@dataclass @dataclass
class CausalLMBatchKeysLast(Batch): class CausalLMBatchKeysLast(CausalLMBatch):
keys_head_dim_last: bool = False keys_head_dim_last: bool = False
@ -543,15 +544,25 @@ class CausalLM(Model):
config.quantize = quantize config.quantize = quantize
config.speculator = speculator config.speculator = speculator
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
if config.pad_token_id is not None:
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
elif config.eos_token_id is not None:
tokenizer.pad_token_id = config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
) )
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = "" prefix = ""
model = model_class(prefix, config, weights) model = model_class(prefix, config, weights)

View File

@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -187,9 +186,7 @@ def _load_gqa(config, prefix: str, weights):
else: else:
bias = None bias = None
return TensorParallelColumnLinear( return TensorParallelColumnLinear(get_linear(weight, bias=bias))
get_linear(weight, bias=bias, quantize=config.quantize)
)
class FlashCohereAttention(torch.nn.Module): class FlashCohereAttention(torch.nn.Module):
@ -260,8 +257,8 @@ class FlashCohereAttention(torch.nn.Module):
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables, block_tables,
input_lengths,
slots, slots,
input_lengths,
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)

View File

@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
class DbrxConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig):
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "n_heads",
"num_hidden_layers": "n_layers",
}
def __init__( def __init__(
self, self,
d_model: int = 2048, d_model: int = 2048,
@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
@property
def num_key_value_heads(self):
# We can't use the attribute map, since the number of KV
# heads is not top-level.
return self.attn_config.kv_n_heads
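A short sketch (editor's addition) of what the attribute_map above provides, relying on standard transformers PretrainedConfig name mapping; the constructor arguments are illustrative assumptions:

cfg = DbrxConfig(d_model=2048, n_heads=16, n_layers=24)
print(cfg.hidden_size)          # 2048, resolved through attribute_map to d_model
print(cfg.num_attention_heads)  # 16, resolved through attribute_map to n_heads
# num_key_value_heads can't be mapped the same way because it lives on the
# nested attn_config, hence the explicit property above.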
def promote_scalar(x: torch.Tensor) -> torch.Tensor: def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x return x.view(1) if len(x.size()) == 0 else x
@ -235,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls):
if cls == TensorParallelRowLinear: if cls == TensorParallelRowLinear:
expert_slice = expert_slice.t().contiguous() expert_slice = expert_slice.t().contiguous()
linear = get_linear(expert_slice, None, config.quantize) linear = get_linear(expert_slice, None)
experts.append(cls(linear, weights.process_group)) experts.append(cls(linear, weights.process_group))
else: else:
linear = get_linear(expert_slice, None, config.quantize) linear = get_linear(expert_slice, None)
experts.append(cls(linear)) experts.append(cls(linear))
return experts return experts

View File

@ -0,0 +1,980 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
class DeepseekV2Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _load_experts(config, prefix: str, mat: str, weights: Weights):
if config.quantize is not None:
raise NotImplementedError(
"Deepseek V2 does not support weight quantization yet."
)
assert mat in ["gate_proj", "up_proj", "down_proj"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.moe_intermediate_size % world_size == 0
), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.moe_intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.n_routed_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.n_routed_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "down_proj":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
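A quick sketch (editor's addition) of the sharding arithmetic in _load_experts; the sizes are illustrative assumptions, not values from this diff:

moe_intermediate_size, world_size = 1408, 2
block_size = moe_intermediate_size // world_size  # 704 rows of each expert per rank
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    print(rank, start, stop)  # rank 0 -> [0, 704), rank 1 -> [704, 1408)
# gate_proj/up_proj are sliced on rows; down_proj is sliced on columns and
# transposed, so each rank holds a consistent slice of every routed expert.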
class DeepseekV2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
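To make the padding above concrete, a small sketch with typical DeepSeek-V2 head dimensions (editor's assumption, not read from this diff):

qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
head_size = qk_nope_head_dim + qk_rope_head_dim  # 192: query/key width per head
head_pad_size = max(head_size, v_head_dim)       # 192
# Values are right-padded from 128 to 192 per head so flash attention sees one
# uniform head size; the zero padding is sliced off again afterwards via
# attn_output[..., :v_head_dim].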
class DeepseekV2MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config: DeepseekV2Config, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.n_routed_experts = config.n_routed_experts
self.n_expert_group = config.n_group
self.topk_group = config.topk_group
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.routed_scaling_factor = config.routed_scaling_factor
gate_proj = _load_experts(
config, f"{prefix}.experts", "gate_proj", weights
).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
)
self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
self.down_proj = (
_load_experts(config, f"{prefix}.experts", "down_proj", weights)
.view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV2MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
self.top_k,
renormalize=self.norm_topk_prob,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
out = (
fused_experts(
x,
self.gate_up_proj,
self.down_proj,
topk_weights,
topk_ids,
inplace=True,
)
* self.routed_scaling_factor
)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DenseMoE(nn.Module):
def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
self.n_routed_experts = config.n_routed_experts
self.n_expert_group = config.n_group
self.topk_group = config.topk_group
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
#
# Seems like no one quantizes the gate.
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.experts = [
DeepseekV2MLP(
f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
)
for i in range(self.n_routed_experts)
]
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV2MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
# gate_logits: (sequence_length, n_experts)
router_logits = self.gate(x)
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
self.top_k,
renormalize=self.norm_topk_prob,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
def moe_infer_gpu(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
):
weights = torch.zeros(
topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
)
weights.scatter_(1, topk_ids, topk_weight)
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i, expert in enumerate(self.experts):
# Add expert output to out with masking
out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
return out
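A minimal sketch (editor's addition) of the scatter_ trick used above to densify the router output; the shapes are illustrative:

import torch

topk_ids = torch.tensor([[0, 2]])        # one token routed to experts 0 and 2
topk_weight = torch.tensor([[0.7, 0.3]])
weights = torch.zeros(1, 4)              # four experts in total
weights.scatter_(1, topk_ids, topk_weight)
print(weights)  # tensor([[0.7000, 0.0000, 0.3000, 0.0000]])
# Every expert then processes all tokens and unrouted experts contribute zero,
# trading extra compute for very simple, quantization-friendly control flow.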
class DeepseekV2Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV2Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
else:
self.mlp = DeepseekV2MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV2Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
DeepseekV2Layer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV2ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV2Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
# Functions below are from vLLM:
#
# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
#
# Remove after we have synced our version with upstream.
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
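A usage sketch for grouped_topk (editor's addition; the shapes and group layout are assumptions):

import torch

hidden = torch.randn(3, 16)   # 3 tokens; only the token count matters here
gating = torch.randn(3, 8)    # 8 experts laid out as 4 groups of 2
w, ids = grouped_topk(
    hidden, gating, topk=2, renormalize=True, num_expert_group=4, topk_group=2
)
print(w.shape, ids.shape)     # torch.Size([3, 2]) torch.Size([3, 2])
# Only experts inside the 2 best-scoring groups per token are eligible, and with
# renormalize=True each row of w sums to 1.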
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
) -> Dict[str, int]:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
if M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_moe_configs,
invoke_fused_moe_kernel,
moe_align_block_size,
)
M, _ = hidden_states.shape
E, N, _ = w1.shape
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)

View File

@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
from text_generation_server.utils.weights import UnquantizedWeight
class Gemma2Config(PretrainedConfig): class Gemma2Config(PretrainedConfig):
@ -141,24 +142,21 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
if config.quantize not in ["gptq", "awq", "marlin"]: if isinstance(weight, UnquantizedWeight):
weight = weight.to(dtype=weights.dtype).to(device=weights.device) weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size() num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [ assert list(weight.weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size, (num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear( return TensorParallelColumnLinear(get_linear(weight, bias=None))
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemma2Attention(torch.nn.Module): class FlashGemma2Attention(torch.nn.Module):

View File

@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
from text_generation_server.utils.weights import UnquantizedWeight
class GemmaConfig(PretrainedConfig): class GemmaConfig(PretrainedConfig):
@ -141,24 +142,21 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
if config.quantize not in ["gptq", "awq", "marlin"]: if isinstance(weight, UnquantizedWeight):
weight = weight.to(dtype=weights.dtype).to(device=weights.device) weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size() num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [ assert list(weight.weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size, (num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear( return TensorParallelColumnLinear(get_linear(weight, bias=None))
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemmaAttention(torch.nn.Module): class FlashGemmaAttention(torch.nn.Module):

View File

@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
# Weights # Weights
weight = weights.get_weights_col_packed_qkv( weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn", f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads, config.num_attention_heads,
config.num_attention_heads, config.num_attention_heads,
) )
@ -83,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights):
bias = torch.cat(tensors, dim=0) bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device) bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias))
def _load_qkv(config, prefix: str, weights, head_size, num_heads): def _load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -130,14 +129,14 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
3 * num_heads * head_size 3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}" ], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices.""" """load_row, but with transposed weight matrices."""
if config.quantize == "gptq": if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
else: else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
@ -148,16 +147,14 @@ def load_row(config, prefix: str, weights, bias: bool):
bias = None bias = None
return TensorParallelRowLinear( return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group get_linear(weight, bias), process_group=weights.process_group
) )
def load_col(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices.""" """load_col, but with transposed weight matrices."""
if config.quantize == "gptq": if config.quantize == "gptq":
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col([prefix], dim=1)
[prefix], quantize=config.quantize, dim=1
)
else: else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
@ -166,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool):
else: else:
bias = None bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias))
class FlashGPT2Attention(torch.nn.Module): class FlashGPT2Attention(torch.nn.Module):

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
@ -25,7 +26,6 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
attention, attention,
reshape_and_cache, reshape_and_cache,
) )
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -46,6 +45,11 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
from text_generation_server.utils.weights import (
UnquantizedWeight,
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
@ -105,6 +109,19 @@ def load_attention(config, prefix: str, weights, layer_id):
) )
@contextmanager
def no_fp8(weights: Weights):
"""De-activate fp8 auto conversion for the duration of this context manager"""
weights_loader = weights.weights_loader
if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
weights_loader = HybridFP8UnquantLoader(
weights_loader.activation_scale_ub, to_fp8=False
)
with weights.use_loader(weights_loader):
yield
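# Editor's usage sketch for the no_fp8 context manager above (illustrative,
# shown as comments since it needs real config/weights objects to run):
#
#     with no_fp8(weights):
#         head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
#
# Inside the block a HybridFP8UnquantLoader temporarily has to_fp8=False, so
# sensitive tensors (embeddings, lm_head, and the first/last layers below) are
# loaded unquantized even when the rest of the model is auto-converted to fp8.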

class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,

@@ -330,12 +347,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
    def __init__(self, index, prefix, config, weights):
        super().__init__()
+        with no_fp8(weights):
            self.self_attn = FlashLlamaAttention(
                index=index,
                prefix=f"{prefix}.self_attn",
                config=config,
                weights=weights,
            )
+
        self.mlp = LlamaMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
        )

@@ -396,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module):
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
-        self.layers = nn.ModuleList(
+
+        # Skip fp8 quant for first and last layers
+        self.layers = nn.ModuleList()
+        with no_fp8(weights):
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=0,
+                    prefix=(
+                        "model.layers.0" if not prefix else f"{prefix}.model.layers.0"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
+        self.layers.extend(
            [
                FlashLlamaLayer(
                    index=layer_id,

@@ -408,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module):
                    config=config,
                    weights=weights,
                )
-                for layer_id in range(config.num_hidden_layers)
+                # Skip first and last layers
+                for layer_id in range(1, config.num_hidden_layers - 1)
            ]
        )
+
+        with no_fp8(weights):
+            last_layer_id = config.num_hidden_layers - 1
+            self.layers.append(
+                FlashLlamaLayer(
+                    index=last_layer_id,
+                    prefix=(
+                        f"model.layers.{last_layer_id}"
+                        if not prefix
+                        else f"{prefix}.model.layers.{last_layer_id}"
+                    ),
+                    config=config,
+                    weights=weights,
+                )
+            )
+
        self.norm = FastRMSNorm.load(
            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
            weights=weights,

@@ -470,9 +522,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

+        with no_fp8(weights):
            self.embed_tokens = TensorParallelEmbedding(
                prefix=(
-                    "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+                    "model.embed_tokens"
+                    if not prefix
+                    else f"{prefix}.model.embed_tokens"
                ),
                weights=weights,
            )

@@ -482,6 +537,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        else:
            suffix = "lm_head"

+        with no_fp8(weights):
            self.lm_head = SpeculativeHead.load(
                config,
                prefix=suffix if not prefix else f"{prefix}.{suffix}",
@@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
        dim=0,
    )

@@ -150,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
        config.hidden_size,
    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear(get_linear(weight, bias=None))


def _load_experts(config, prefix: str, mat, weights):
@@ -24,7 +24,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
-from transformers.models.gpt_neox import GPTNeoXConfig
+from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple

from text_generation_server.layers.attention import (

@@ -45,10 +45,17 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
+from text_generation_server.utils.weights import UnquantizedWeight
+
+
+class GPTNeoXConfig(TransformersGPTNeoXConfig):
+    attribute_map = {
+        "num_key_value_heads": "num_attention_heads",
+    }
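A note on the new config subclass: `attribute_map` is the standard `transformers` `PretrainedConfig` aliasing mechanism, so reads of `num_key_value_heads` transparently resolve to `num_attention_heads` (NeoX has no separate kv heads). A small illustrative check, assuming `transformers` is installed:

    # Illustrative: attribute_map makes `num_key_value_heads` an alias
    # for `num_attention_heads` on this config subclass.
    config = GPTNeoXConfig(num_attention_heads=16)
    assert config.num_key_value_heads == 16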

def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+    weight = weights.get_weights_row(prefix)

    if bias and weights.process_group.rank() == 0:
        # Rank is only on the first rank process

@@ -56,7 +63,7 @@ def load_row(config, prefix: str, weights, bias: bool):
    else:
        bias = None

-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
    if config.use_parallel_residual:
        return linear
    else:

@@ -64,11 +71,11 @@ def load_row(config, prefix: str, weights, bias: bool):


def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
-    weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
-    if isinstance(weight, torch.Tensor):
+    weight = weights.get_multi_weights_col([prefix], dim=0)
+    if isinstance(weight, UnquantizedWeight):
        # Only on non quantized versions
-        weight = (
-            weight.view(
+        weight.weight = (
+            weight.weight.view(
                num_heads,
                3,
                head_size,

@@ -81,7 +88,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
    bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
    if config.use_parallel_residual:
        return linear
    else:
@@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
        dim=0,
    )

@@ -101,9 +100,7 @@ def _load_gqa(config, prefix: str, weights):
    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    # this is the same as llama except for Phi uses bias=True
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=True, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear(get_linear(weight, bias=True))


class FlashPhiAttention(torch.nn.Module):
@@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (


def load_row(config, prefix: str, weights, bias: bool):
-    weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+    weight = weights.get_weights_row(prefix)

    if bias and weights.process_group.rank() == 0:
        # Rank is only on the first rank process

@@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool):
    else:
        bias = None

-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
    if config.parallel_attn:
        return linear
    else:

@@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
+        "num_key_value_heads": "n_head_kv",
    }

    def __init__(
@@ -17,6 +17,7 @@ from text_generation_server.layers import (
    TensorParallelEmbedding,
    get_linear,
)
+from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)

@@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
        qzeros = qzeros.to(device=weights.device)

-        gptq_params = weights._get_gptq_params()
-        if gptq_params.quant_method == "gptq":
+        loader = weights.weights_loader
+        assert isinstance(loader, GPTQWeightsLoader)
+        loader._get_gptq_params(weights)
+        if loader.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
            g_idx = g_idx.to(device=weights.device)
-        elif gptq_params.quant_method == "awq":
+        elif loader.quant_method == "awq":
            g_idx = None
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,

@@ -100,8 +103,9 @@ def _load_multi_mqa_gptq(
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
-            bits=gptq_params.bits,
-            groupsize=gptq_params.groupsize,
+            bits=loader.bits,
+            groupsize=loader.groupsize,
+            use_awq_kernel=loader.quantize == "awq",
            use_exllama=HAS_EXLLAMA,
        )

@@ -118,7 +122,7 @@ def _load_multi_mqa_gptq(
        bias = torch.cat([q_tensor, kv_tensor], dim=0)
        bias = bias.to(device=weights.device)

-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return TensorParallelColumnLinear(get_linear(weight, bias))
    else:
        raise NotImplementedError("Gptq loading with santacoder is not implemented")

@@ -190,29 +194,27 @@ def _load_multi_mqa(
    assert list(bias.shape) == [
        (num_heads + 2) * head_size
    ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_col(config, prefix: str, weights, bias: bool):
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
    else:
-        weight = weights.get_multi_weights_col(
-            [prefix], quantize=config.quantize, dim=0
-        )
+        weight = weights.get_multi_weights_col([prefix], dim=0)

    if bias:
        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
    else:
        bias = None
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_row(config, prefix: str, weights, bias: bool):
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
    else:
-        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
+        weight = weights.get_weights_row(prefix)

    if bias and weights.process_group.rank() == 0:
        # Rank is only on the first rank process

@@ -220,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool):
    else:
        bias = None
    return TensorParallelRowLinear(
-        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+        get_linear(weight, bias), process_group=weights.process_group
    )
@@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
+from text_generation_server.utils.weights import UnquantizedWeight


class Starcoder2Config(PretrainedConfig):

@@ -126,20 +127,19 @@ def _load_gqa(config, prefix: str, weights):
    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
        dim=0,
    )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-        assert list(weight.shape) == [
+        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
-        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.use_bias:
        w = [

@@ -150,9 +150,7 @@ def _load_gqa(config, prefix: str, weights):
    else:
        bias = None

-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=bias, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias))


class Starcoder2Attention(torch.nn.Module):
@@ -34,6 +34,7 @@ from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

@@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
class Idefics2ForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
-        config.vision_config.quantize = config.quantize
+        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator

@@ -695,16 +696,24 @@ class Idefics2ForConditionalGeneration(nn.Module):
            name="text_model",
        )
        self.dtype = weights.dtype
+
+        # The vision and connector models are not quantized.
+        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
            self.vision_model = Idefics2VisionTransformer(
-                prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
+                prefix=(
+                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+                ),
                config=vision_config,
                weights=weights,
            )
+
+            config.quantize = None
            self.connector = Idefics2Connector(
                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                config=config,
                weights=weights,
            )
+
        self.config = config
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = config.image_token_id
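`use_loader` itself is defined in `weights.py` and is not visible in this excerpt; it presumably swaps the active loader for the duration of the block, along these lines (a sketch under that assumption, not the actual implementation):

    from contextlib import contextmanager

    class WeightsSketch:
        def __init__(self, weights_loader):
            self.weights_loader = weights_loader

        @contextmanager
        def use_loader(self, weights_loader):
            # Temporarily replace the active loader, restoring it afterwards.
            old_loader = self.weights_loader
            self.weights_loader = weights_loader
            try:
                yield
            finally:
                self.weights_loader = old_loader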
@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias):
bias = bias.to(device=weights.device) bias = bias.to(device=weights.device)
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return TensorParallelColumnLinear(linear) return TensorParallelColumnLinear(linear)
@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module):
weights, weights,
): ):
super().__init__() super().__init__()
attn_impl = config.attn_config["attn_impl"] attn_impl = config.attn_config.attn_impl
self.attn_impl = config.attn_config["attn_impl"] self.attn_impl = config.attn_config.attn_impl
self.clip_qkv = config.attn_config["clip_qkv"] self.clip_qkv = config.attn_config.clip_qkv
self.qk_ln = config.attn_config["qk_ln"] self.qk_ln = config.attn_config.qk_ln
self.d_model = config.d_model self.d_model = config.d_model
d_model = config.d_model d_model = config.d_model
self.n_heads = config.n_heads self.n_heads = config.n_heads
self.softmax_scale = config.attn_config["softmax_scale"] self.softmax_scale = config.attn_config.softmax_scale
if self.softmax_scale is None: if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = config.attn_config["attn_pdrop"] self.attn_dropout_p = config.attn_config.attn_pdrop
if self.n_heads % weights.process_group.size() != 0: if self.n_heads % weights.process_group.size() != 0:
raise ValueError( raise ValueError(
@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
attn_impl = config.attn_config["attn_impl"] attn_impl = config.attn_config.attn_impl
self.attn_impl = config.attn_config["attn_impl"] self.attn_impl = config.attn_config.attn_impl
self.clip_qkv = config.attn_config["clip_qkv"] self.clip_qkv = config.attn_config.clip_qkv
self.qk_ln = config.attn_config["qk_ln"] self.qk_ln = config.attn_config.qk_ln
self.d_model = config.d_model self.d_model = config.d_model
d_model = config.d_model d_model = config.d_model
self.n_heads = config.n_heads self.n_heads = config.n_heads
self.softmax_scale = config.attn_config["softmax_scale"] self.softmax_scale = config.attn_config.softmax_scale
if self.softmax_scale is None: if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.head_dim) self.softmax_scale = 1 / math.sqrt(self.head_dim)
self.attn_dropout_p = config.attn_config["attn_pdrop"] self.attn_dropout_p = config.attn_config.attn_pdrop
# self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
self.Wqkv = TensorParallelColumnLinear.load( self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@ -614,9 +614,9 @@ class MPTBlock(nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.prefix = prefix self.prefix = prefix
if config.attn_config["attn_type"] != "multihead_attention": if config.attn_config.attn_type != "multihead_attention":
raise NotImplementedError( raise NotImplementedError(
f"""Not implemented attn {config.attn_config["attn_type"]}""" f"""Not implemented attn {config.attn_config.attn_type}"""
) )
resid_pdrop = config.resid_pdrop resid_pdrop = config.resid_pdrop
if config.no_bias: if config.no_bias:
@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel):
self.world_size = weights.process_group.size() self.world_size = weights.process_group.size()
self.rank = weights.process_group.rank() self.rank = weights.process_group.rank()
self.n_heads = config.n_heads self.n_heads = config.n_heads
self.attn_impl = config.attn_config["attn_impl"] self.attn_impl = config.attn_config.attn_impl
self.prefix_lm = config.attn_config["prefix_lm"] self.prefix_lm = config.attn_config.prefix_lm
self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
self.alibi = config.attn_config["alibi"] self.alibi = config.attn_config.alibi
self.alibi_bias_max = config.attn_config["alibi_bias_max"] self.alibi_bias_max = config.attn_config.alibi_bias_max
if config.init_device == "mixed": if config.init_device == "mixed":
if dist.get_local_rank() == 0: if dist.get_local_rank() == 0:
config.init_device = "cpu" config.init_device = "cpu"
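The switch from `config.attn_config["..."]` to attribute access suggests `attn_config` is now materialized as a config object rather than a raw dict (the MPT config class itself is outside this excerpt, so this is an inference). A minimal sketch of the pattern, using a stand-in namespace:

    # Illustrative stand-in: attribute access over a nested config,
    # mirroring the new config.attn_config.attn_impl style.
    from types import SimpleNamespace

    attn_config = SimpleNamespace(attn_impl="torch", clip_qkv=None, qk_ln=False)
    assert attn_config.attn_impl == "torch"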
@@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
+from text_generation_server.utils.log import log_master
from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
-    hub,
)
from text_generation_server.models.types import (
    Batch,

@@ -50,6 +49,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
+from text_generation_server.utils.quantization import get_loader

from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (

@@ -838,7 +838,9 @@ class FlashCausalLM(Model):
        default_dtype=torch.float16,
        aliases=None,
        # Used for Santacoder override of config
-        num_kv_heads=None,
+        num_kv_heads: Optional[int] = None,
+        # Deepseek V2 uses different QK and V dims.
+        head_size: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()

@@ -881,12 +883,16 @@ class FlashCausalLM(Model):
        torch.distributed.barrier(group=self.process_group)

+        weights_loader = get_loader(quantize, model_id, revision)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
-            filenames, device, dtype, process_group=self.process_group, aliases=aliases
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases=aliases,
+            weights_loader=weights_loader,
        )
-        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = model_class(prefix, config, weights)

@@ -905,15 +911,23 @@ class FlashCausalLM(Model):
        self.num_layers = config.num_hidden_layers
        # Validation is done in the model itself
        if num_kv_heads is None:
-            # Order is important here.
-            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
-                num_kv_heads = getattr(config, attr, None)
-                if num_kv_heads is not None:
-                    break
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            # GPT-2 workaround
+            if num_kv_heads is None:
+                num_kv_heads = getattr(config, "n_head", None)
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
-        self.num_kv_heads = num_kv_heads // self.process_group.size()
-        self.head_size = config.hidden_size // config.num_attention_heads
+        self.num_kv_heads = (
+            num_kv_heads // self.process_group.size()
+            if num_kv_heads > 1
+            else num_kv_heads
+        )
+        assert self.num_kv_heads > 0
+
+        if head_size is None:
+            self.head_size = config.hidden_size // config.num_attention_heads
+        else:
+            self.head_size = head_size
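The new kv-head bookkeeping shards `num_kv_heads` across tensor-parallel ranks, except for multi-query models with a single kv head, which replicate it; the `assert` above catches world sizes larger than the kv-head count. Worked numbers (illustrative only):

    # world_size = number of tensor-parallel shards
    world_size = 4
    for num_kv_heads in (32, 8, 1):
        shard = num_kv_heads // world_size if num_kv_heads > 1 else num_kv_heads
        print(num_kv_heads, "->", shard)  # 32 -> 8, 8 -> 2, 1 -> 1 (MQA: replicated)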
        self.cuda_graphs = {}
        self.kv_cache = []

@@ -1141,31 +1155,36 @@ class FlashCausalLM(Model):
                f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
            )

-            logger.info(
-                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
+            log_master(
+                logger.info,
+                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
            )

            if os.path.isfile(tunableop_filepath):
-                logger.info(
-                    f"The file {tunableop_filepath} already exists and will be reused."
+                log_master(
+                    logger.info,
+                    f"The file {tunableop_filepath} already exists and will be reused.",
                )
                torch.cuda.tunable.read_file(tunableop_filepath)

            os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

            for seqlen in tuning_sequences:
-                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+                log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                self.tunableop_warmup(seqlen)
            torch.cuda.tunable.write_file(tunableop_filepath)
            torch.cuda.tunable.tuning_enable(False)
        else:
-            logger.info(
-                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+            log_master(
+                logger.info,
+                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.",
            )

        if CUDA_GRAPHS:
            try:
-                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
+                log_master(
+                    logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
+                )
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:

@@ -1173,7 +1192,9 @@ class FlashCausalLM(Model):
            except torch.cuda.OutOfMemoryError:
                logger.exception(f"Decode cuda graph warmup failed")
        else:
-            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
+            log_master(
+                logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
+            )

        return int(num_blocks * BLOCK_SIZE)

@@ -1525,8 +1546,7 @@ class FlashCausalLM(Model):
            left = 0

            if n_accepted_ids > 1:
-                if RANK == 0:
-                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
@@ -1,15 +1,16 @@
import torch
import os

from loguru import logger
-from typing import Dict
+from typing import Dict, Optional
+from text_generation_server.utils.log import log_master

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
-    logger.info("Using FLASH_DECODING")
+    log_master(logger.info, "Using FLASH_DECODING")

cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:

@@ -26,11 +27,9 @@ else:
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)
CUDA_GRAPHS = cuda_graphs

# This is overridden at model loading.
-global MODEL_ID
MODEL_ID = None

@@ -41,8 +40,7 @@ def set_model_id(model_id: str):
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
-global ADAPTER_TO_INDEX
-ADAPTER_TO_INDEX: Dict[str, int] = None
+ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
@@ -23,6 +23,7 @@ from text_generation_server.utils import (
    weight_files,
    Weights,
)
+from text_generation_server.utils.quantization import get_loader


class IDEFICSSharded(IdeficsCausalLM):

@@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
            trust_remote_code=trust_remote_code,
        )

+        weights_loader = get_loader(
+            quantize=quantize, model_id=model_id, revision=revision
+        )
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(

@@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
            device=device,
            dtype=dtype,
            process_group=self.process_group,
+            weights_loader=weights_loader,
        )

        model = IdeficsForVisionText2Text(config, weights)
@@ -28,6 +28,7 @@ from text_generation_server.models.types import (
    GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

@@ -448,8 +449,17 @@ class Mamba(Model):
        config.quantize = quantize
        config.speculator = speculator
        torch.distributed.barrier(group=self.process_group)
+        weights_loader = get_loader(
+            quantize=quantize, model_id=model_id, revision=revision
+        )
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            weights_loader=weights_loader,
+        )
        model = MambaModel(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(Mamba, self).__init__(
@@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import (
    AdapterParameters,
    AdapterSource,
)
+from text_generation_server.utils.log import log_master
from loguru import logger

@@ -204,8 +205,9 @@ class Model(ABC):
                f"order to use the dynamic adapter loading feature."
            )

-        logger.info(
-            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+        log_master(
+            logger.info,
+            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}",
        )
        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
        (

@@ -240,8 +242,9 @@ class Model(ABC):
                    layer_weights.add_adapter(adapter_index, adapter_weights)

            if len(unused_weight_names) > 0:
-                logger.warning(
-                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+                log_master(
+                    logger.warning,
+                    f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}",
                )

            if adapter_tokenizer is not None:
@@ -18,6 +18,7 @@ from text_generation_server.utils import (
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (

@@ -586,6 +587,9 @@ class Seq2SeqLM(Model):
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

+        weights_loader = get_loader(
+            quantize=quantize, model_id=model_id, revision=revision
+        )
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(

@@ -594,6 +598,7 @@ class Seq2SeqLM(Model):
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
+            weights_loader=weights_loader,
        )
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)
@@ -1,4 +1,3 @@
-from itertools import repeat
import torch
from PIL import Image
from io import BytesIO

@@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
)
+from text_generation_server.utils.log import log_master
from transformers import AutoProcessor

tracer = trace.get_tracer(__name__)

@@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
        num_features = get_number_of_features(height, width, config)
        from loguru import logger

-        logger.info(
-            f"Found {num_features} features in image of resolution {height}x{width}"
+        log_master(
+            logger.info,
+            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

@@ -261,7 +262,12 @@ class VlmCausalLM(FlashCausalLM):
            **processor_kwargs,
        )
        self.batch_class = batch_class
-        super().__init__(model_id=model_id, **kwargs)
+        super().__init__(
+            model_id=model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
@@ -56,7 +56,7 @@ def initialize_torch_distributed():
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
-        options._timeout = timedelta(seconds=60)
+        options._timeout = timedelta(seconds=120)
    else:
        backend = "gloo"
        options = None

@@ -76,7 +76,7 @@ def initialize_torch_distributed():
                backend="ccl",
                world_size=WORLD_SIZE,
                rank=RANK,
-                timeout=timedelta(seconds=60),
+                timeout=timedelta(seconds=120),
                pg_options=options,
            )
        else:

@@ -84,7 +84,7 @@ def initialize_torch_distributed():
                backend=backend,
                world_size=WORLD_SIZE,
                rank=RANK,
-                timeout=timedelta(seconds=60),
+                timeout=timedelta(seconds=120),
                pg_options=options,
            )
    else:
@@ -1,6 +1,15 @@
from functools import lru_cache
+from text_generation_server.utils.dist import RANK


@lru_cache(10)
-def log_once(log, msg: str):
-    log(msg)
+def log_once(log, msg: str, master=True):
+    if master:
+        log_master(log, msg)
+    else:
+        log(msg)


+def log_master(log, msg: str):
+    if RANK == 0:
+        log(msg)
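Usage of the two helpers above is straightforward; only the question of which ranks emit the message differs (a sketch, assuming a distributed run where RANK is set):

    from loguru import logger

    # Emitted by rank 0 only, keeping multi-shard logs readable.
    log_master(logger.info, "Warming up model")

    # Emitted at most once per (log, msg) pair thanks to lru_cache;
    # with master=True (the default) it also goes through log_master.
    log_once(logger.warning, "Experimental feature enabled")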
@@ -0,0 +1,173 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+from huggingface_hub import hf_hub_download
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    WeightsLoader,
+)
+
+
+# TODO: Split this config to have a single config type per quant method
+@dataclass
+class _QuantizerConfig:
+    bits: int
+    checkpoint_format: Optional[str]
+    desc_act: bool
+    groupsize: int
+    quant_method: str
+    sym: bool
+
+
+@dataclass
+class _FP8QuantizerConfig:
+    activation_scale_ub: float
+
+
+# We should probably do this with Pydantic JSON deserialization,
+# but for now we'll stay close to the old _set_gptq_params.
+def _get_quantizer_config(model_id, revision):
+    bits = 4
+    groupsize = -1
+    quant_method = "gptq"
+    checkpoint_format = None
+    sym = True
+    desc_act = False
+
+    filename = "config.json"
+    try:
+        if os.path.exists(os.path.join(model_id, filename)):
+            filename = os.path.join(model_id, filename)
+        else:
+            filename = hf_hub_download(model_id, filename=filename, revision=revision)
+        with open(filename, "r") as f:
+            data = json.load(f)
+
+        # FP8 config
+        if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
+            return _FP8QuantizerConfig(
+                activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
+            )
+
+        bits = data["quantization_config"]["bits"]
+        groupsize = data["quantization_config"]["group_size"]
+        # Order is important here, desc_act is missing on some real models
+        quant_method = data["quantization_config"]["quant_method"]
+        checkpoint_format = data["quantization_config"].get("checkpoint_format")
+        sym = data["quantization_config"]["sym"]
+        desc_act = data["quantization_config"]["desc_act"]
+    except Exception:
+        filename = "quantize_config.json"
+        try:
+            if os.path.exists(os.path.join(model_id, filename)):
+                filename = os.path.join(model_id, filename)
+            else:
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
+            with open(filename, "r") as f:
+                data = json.load(f)
+            bits = data["bits"]
+            groupsize = data["group_size"]
+            sym = data["sym"]
+            desc_act = data["desc_act"]
+            if "version" in data and data["version"] == "GEMM":
+                quant_method = "awq"
+        except Exception:
+            filename = "quant_config.json"
+            try:
+                if os.path.exists(os.path.join(model_id, filename)):
+                    filename = os.path.join(model_id, filename)
+                else:
+                    filename = hf_hub_download(
+                        model_id, filename=filename, revision=revision
+                    )
+                with open(filename, "r") as f:
+                    data = json.load(f)
+                bits = data["w_bit"]
+                groupsize = data["q_group_size"]
+                desc_act = data["desc_act"]
+                if "version" in data and data["version"] == "GEMM":
+                    quant_method = "awq"
+            except Exception:
+                pass
+
+    return _QuantizerConfig(
+        bits=bits,
+        groupsize=groupsize,
+        quant_method=quant_method,
+        checkpoint_format=checkpoint_format,
+        sym=sym,
+        desc_act=desc_act,
+    )
+
+
+def get_loader(
+    quantize: Optional[str], model_id: str, revision: Optional[str]
+) -> WeightsLoader:
+    quantizer_config = _get_quantizer_config(model_id, revision)
+    if quantize in {"awq", "gptq"}:
+        from text_generation_server.layers.gptq import GPTQWeightsLoader
+
+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
+        return GPTQWeightsLoader(
+            bits=quantizer_config.bits,
+            desc_act=quantizer_config.desc_act,
+            groupsize=quantizer_config.groupsize,
+            quant_method=quantizer_config.quant_method,
+            quantize=quantize,
+            sym=quantizer_config.sym,
+        )
+    elif quantize == "bitsandbytes":
+        from text_generation_server.layers.bnb import BNBWeight
+
+        return DefaultWeightsLoader(BNBWeight)
+    elif quantize == "bitsandbytes-fp4":
+        from text_generation_server.layers.bnb import BNBFP4Weight
+
+        return DefaultWeightsLoader(BNBFP4Weight)
+    elif quantize == "bitsandbytes-nf4":
+        from text_generation_server.layers.bnb import BNBNF4Weight
+
+        return DefaultWeightsLoader(BNBNF4Weight)
+    elif quantize == "eetq":
+        from text_generation_server.layers.eetq import EETQWeight
+
+        return DefaultWeightsLoader(EETQWeight)
+    elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2WeightsLoader
+
+        return Exl2WeightsLoader()
+    elif quantize == "marlin":
+        from text_generation_server.layers.marlin import MarlinWeightsLoader
+
+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
+        return MarlinWeightsLoader(
+            bits=quantizer_config.bits,
+            is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
+        )
+    elif quantize == "fp8" or quantize is None:
+        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+        # Since the default for the quantize config is _QuantizerConfig,
+        # we need to add this check to not get an attribute error
+        activation_scale_ub = None
+        if isinstance(quantizer_config, _FP8QuantizerConfig):
+            activation_scale_ub = quantizer_config.activation_scale_ub
+
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+    else:
+        raise ValueError(f"Unknown quantization method: {quantize}")
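Taken together, `get_loader` replaces the old per-model `_set_gptq_params` plumbing: callers resolve a `WeightsLoader` once and hand it to `Weights`. A wiring sketch (identifiers such as `filenames`, `device`, `dtype` and `process_group` are placeholders resolved elsewhere, e.g. in `FlashCausalLM.__init__`, and the model id is invented):

    # Hedged sketch of the new call pattern, not a runnable program on its own.
    weights_loader = get_loader(quantize="gptq", model_id="org/gptq-model", revision=None)
    weights = Weights(
        filenames,                # safetensors files from weight_files(...)
        device,
        dtype,
        process_group=process_group,
        weights_loader=weights_loader,
    )
    # Model code then asks for weights without passing `quantize` around:
    w = weights.get_multi_weights_col(["model.layers.0.self_attn.q_proj"], dim=0)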
@@ -1,13 +1,140 @@
-import os
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-from safetensors import safe_open, SafetensorError
import torch
-from loguru import logger
-from huggingface_hub import hf_hub_download
-import json
-from text_generation_server.layers.gptq import GPTQParams
-from text_generation_server.utils.log import log_once
+
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, List, Optional, Union, Type
+from safetensors import safe_open
+from dataclasses import dataclass
+
+from text_generation_server.utils.import_utils import SYSTEM
+
+
+class WeightsLoader(ABC):
+    """
+    Instances of this type implement higher-level weight loading.
+
+    At a low-level, every weight is stored in the Safetensors format.
+    The interpretation of weights may be different however, for instance
+    they could be packed, quantized weights. Loaders are responsible for
+    interpreting the raw tensors, sharding tensors in a manner compatible
+    with the format, etc.
+    """
+
+    @abstractmethod
+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor parallelism.
+        """
+        ...
+
+    @abstractmethod
+    def get_weights_col_packed(
+        self,
+        weights: "Weights",
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        """
+        Get the packed weights at the given prefix with column-splitting for
+        tensor parallelism. This method should be used when multiple different
+        weights are packed into a tensor, for instance, query/key/value
+        weights or a gate/up projection.
+
+        The `block_sizes` determines the proportions of the packed tensors.
+        The columns are split in equally sized blocks when `block_sizes` is an
+        `int`, or in blocks proportional given to the sizes. For instance
+        `[2, 1, 1]` will divide an input with dimensionality `1024` in
+        `[512, 256, 256]`.
+        """
+        ...
+
+    def get_weights_col(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply column-splitting for tensor
+        parallelism.
+        """
+        return weights.get_multi_weights_col([prefix], 0)
+
+    @abstractmethod
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        """
+        Get the weights at the given prefixes, column-split them for tensor
+        parallelism, and then concatenate the weights along the given dimension.
+        """
+        ...
+
+    @abstractmethod
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        """
+        Get the weights at the given prefix and apply row-splitting for tensor
+        parallelism.
+        """
+        ...
+
+
+class Weight(ABC):
+    """Instances of this type implement unquantized/quantized/to-be
+    quantized weights."""
+
+    @abstractmethod
+    def get_linear(self, bias: torch.Tensor):
+        """Create a linear layer from this weight."""
+        ...
+
+
+@dataclass
+class UnquantizedWeight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        from text_generation_server.layers.linear import FastLinear, FastLinearROCm
+
+        if SYSTEM == "rocm":
+            return FastLinearROCm(self.weight, bias)
+        else:
+            return FastLinear(self.weight, bias)
+
+
+class DefaultWeightsLoader(WeightsLoader):
+    """Weight loader that loads (unquantized) Torch tensors as-is, only
+    applying sharding and/or concatenation."""
+
+    def __init__(self, weight_class: Type[UnquantizedWeight]):
+        """Create a loader. Weights will be wrapped using the given `weights_class`,
+        normally this will be `UnquantizedWeight`, but a quantizer-specific class
+        such as `Fp8Weight` can be used to quantize the weights during loading.
+        """
+        self.weight_class = weight_class
+
+    def get_weights(self, weights: "Weights", prefix: str):
+        return weights.get_tensor(f"{prefix}.weight")
+
+    def get_weights_col_packed(
+        self,
+        weights: "Weights",
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        return self.weight_class(
+            weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            ),
+        )
+
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        return self.weight_class(torch.cat(w, dim=dim))
+
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        return self.weight_class(
+            weights.get_sharded(f"{prefix}.weight", dim=1),
+        )
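To make the `block_sizes` semantics concrete before the `Weights` class below, a tiny worked example of the proportional split described in the `get_weights_col_packed` docstring (numbers illustrative):

    total = 1024
    block_sizes = [2, 1, 1]
    unit = total // sum(block_sizes)            # 256
    splits = [b * unit for b in block_sizes]    # [512, 256, 256]
    assert sum(splits) == total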
class Weights: class Weights:
@ -17,6 +144,7 @@ class Weights:
device, device,
dtype, dtype,
process_group, process_group,
weights_loader: WeightsLoader,
aliases: Optional[Dict[str, List[str]]] = None, aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
): ):
@ -37,6 +165,7 @@ class Weights:
self.dtype = dtype self.dtype = dtype
self.process_group = process_group self.process_group = process_group
self.prefix = prefix self.prefix = prefix
self.weights_loader = weights_loader
self._handles = {} self._handles = {}
def _get_handle(self, filename): def _get_handle(self, filename):
@ -69,23 +198,39 @@ class Weights:
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
return slice_ return slice_
def _has_tensor(self, tensor_name: str):
try:
self.get_filename(tensor_name)
except Exception:
return False
return True
def get_shape(self, tensor_name: str): def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape() return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True): def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name) tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert # Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16 # u4 which are disguised as int32. Exl2 uses int16
# as well. # as well. FP8 uses torch.float8_e4m3fn
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]
and to_dtype
):
tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(dtype=self.dtype)
if to_device: if to_device:
tensor = tensor.to(device=self.device) tensor = tensor.to(device=self.device)
return tensor return tensor
def get_partial_sharded(self, tensor_name: str, dim: int): def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
@ -105,12 +250,16 @@ class Weights:
raise NotImplementedError("Let's make that generic when needed") raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert # Special case for gptq which shouldn't convert
# u4 which are disguised as int32. exl2 uses int16. # u4 which are disguised as int32. exl2 uses int16.
if tensor.dtype not in (torch.int16, torch.int32): # FP8 uses torch.float8_e4m3fn.
if (
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
and to_dtype
):
tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device) tensor = tensor.to(device=self.device)
return tensor return tensor
def get_sharded(self, tensor_name: str, dim: int): def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
@ -119,10 +268,14 @@ class Weights:
assert ( assert (
size % world_size == 0 size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards" ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim) return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
def get_packed_sharded( def get_packed_sharded(
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] self,
tensor_name: str,
dim: int,
block_sizes: Union[int, List[int]],
to_dtype=True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Get a shard from a tensor that packs multiple tensors. Get a shard from a tensor that packs multiple tensors.
@ -168,308 +321,51 @@ class Weights:
tensor = tensor.to(device=self.device) tensor = tensor.to(device=self.device)
# Avoid casting quantizer dtypes. # Avoid casting quantizer dtypes.
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]
and to_dtype
):
tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(dtype=self.dtype)
return tensor return tensor
def get_weights(self, prefix: str):
return self.weights_loader.get_weights(self, prefix)
def get_weights_col_packed_qkv( def get_weights_col_packed_qkv(
self, self,
prefix: str, prefix: str,
quantize: str,
num_heads: int, num_heads: int,
num_key_value_heads: int, num_key_value_heads: int,
): ):
return self.get_weights_col_packed( return self.get_weights_col_packed(
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] prefix, [num_heads, num_key_value_heads, num_key_value_heads]
) )
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): def get_weights_col_packed_gate_up(self, prefix: str):
return self.get_weights_col_packed(prefix, quantize, 2) return self.get_weights_col_packed(prefix, 2)
def get_weights_col_packed( def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
):
""" """
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor.
The columns are split in equally sized blocks when blocks is an `int`, or The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of convenient for e.g. splitting QKV without knowing the storage details of
quantized weights. quantized weights.
""" """
if quantize in ["gptq", "awq"]: return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try: def get_weights_col(self, prefix: str):
qweight = self.get_packed_sharded( return self.weights_loader.get_weights_col(self, prefix)
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=self.dtype)
gptq_params = self._get_gptq_params() def get_multi_weights_col(self, prefixes: List[str], dim: int):
if can_use_gptq_marlin(gptq_params, quantize): return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
g_idx = self.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=False,
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
B = self.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = self.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
return weight
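The AWQ-conversion branch in the method above rebuilds a trivial `g_idx` mapping each unpacked input feature to its quantization group. The same arithmetic as a worked sketch (shapes illustrative):

import torch

bits, groupsize = 4, 128
packed_rows = 512  # qweight.shape[0]: each int32 row packs 32 // bits = 8 features
in_features = packed_rows * (32 // bits)  # 4096 unpacked input features

# Feature i belongs to quantization group i // groupsize.
g_idx = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)
assert g_idx[0].item() == 0 and g_idx[127].item() == 0 and g_idx[128].item() == 1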
def get_weights_col(self, prefix: str, quantize: str):
if quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
return self.get_multi_weights_col([prefix], quantize, 0)
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "exl2":
raise ValueError("get_multi_weights_col is not supported for exl2")
elif quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
gptq_params.bits == 4
and HAS_EXLLAMA
and quantize == "gptq"
and not gptq_params.desc_act
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
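Call sites change accordingly: the `quantize` string disappears from every signature because the `Weights` instance now carries a loader chosen once at construction. A hedged before/after sketch (prefixes illustrative):

# Before: every layer threaded the quantization scheme through each call.
weight = weights.get_multi_weights_col(
    [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
    quantize=config.quantize,
    dim=0,
)

# After: the loader is fixed when `Weights` is built, so call sites only
# describe which tensors they want.
weight = weights.get_multi_weights_col(
    [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
    dim=0,
)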
def get_tensor_shard(self, var, dim): def get_tensor_shard(self, var, dim):
world_size = self.process_group.size() world_size = self.process_group.size()
@ -487,318 +383,22 @@ class Weights:
tensor = tensor.to(device=self.device) tensor = tensor.to(device=self.device)
return tensor return tensor
def get_multi_weights_row(self, prefix: str, quantize: str): def get_weights_row(self, prefix: str):
if quantize == "exl2": return self.weights_loader.get_weights_row(self, prefix)
from text_generation_server.layers.exl2 import Exl2Weight
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):
"""
Context manager that makes `Weights` use a different loader
for the duration of the context.
"""
old_loader = self.weights_loader
self.weights_loader = weights_loader
try: try:
q_weight = self.get_tensor(f"{prefix}.q_weight") yield
except RuntimeError: finally:
raise RuntimeError( self.weights_loader = old_loader
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
elif quantize == "gptq":
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if gptq_params.bits != 4:
use_exllama = False
if gptq_params.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if gptq_params.quant_method == "gptq":
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
elif gptq_params.quant_method == "awq":
g_idx = None
if self.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[
i // gptq_params.groupsize
for i in range(g_idx.shape[0])
],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split across several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and gptq_params.groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = None
use_exllama = False
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = self.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1; share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
try:
B = self.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1; share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
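The `use_loader` context manager added above is handy when one sub-module must be loaded with a different scheme than the rest of the model, e.g. an unquantized head inside a quantized checkpoint. A minimal usage sketch, assuming a `DefaultWeightsLoader` that loads plain tensors:

# Assumed: DefaultWeightsLoader resolves plain, unquantized tensors.
from text_generation_server.utils.weights import DefaultWeightsLoader

with weights.use_loader(DefaultWeightsLoader()):
    lm_head = weights.get_weights_col("lm_head")
# On exit the previous loader is restored, even if an exception was raised.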
def _get_gptq_params(self) -> GPTQParams:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False
sym = False
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)
except Exception:
raise e
return GPTQParams(
bits=bits,
checkpoint_format=checkpoint_format,
desc_act=desc_act,
groupsize=groupsize,
quant_method=quant_method,
sym=sym,
)
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
self.quant_method = data["quantization_config"]["quant_method"]
self.gptq_checkpoint_format = data["quantization_config"].get(
"checkpoint_format"
)
self.gptq_sym = data["quantization_config"]["sym"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_sym = data["sym"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
pass
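For reference, the three fallbacks above correspond to differently shaped config files. Roughly, with key names taken from the code and values illustrative:

# config.json: nested under "quantization_config".
{"quantization_config": {"bits": 4, "group_size": 128, "quant_method": "gptq",
                         "sym": True, "desc_act": False}}

# quantize_config.json: the same information at the top level.
{"bits": 4, "group_size": 128, "sym": True, "desc_act": False, "version": "GEMM"}

# quant_config.json: AWQ-style key names; "version": "GEMM" flags an AWQ model.
{"w_bit": 4, "q_group_size": 128, "desc_act": False, "version": "GEMM"}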
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
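Only the signature of `_blocks_to_block_sizes` is visible in this hunk; its contract, per the packing docstring above, is the proportional split. A minimal re-implementation sketch under that assumption:

from typing import List, Union

def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    # An int means that many equally sized blocks; a list gives proportions.
    if isinstance(blocks, int):
        assert total_size % blocks == 0, f"{total_size} not divisible into {blocks} blocks"
        return [total_size // blocks] * blocks
    unit = total_size // sum(blocks)
    assert unit * sum(blocks) == total_size, "proportions must divide the total size"
    return [unit * b for b in blocks]

assert _blocks_to_block_sizes(1024, [2, 1, 1]) == [512, 256, 256]
assert _blocks_to_block_sizes(1024, 4) == [256, 256, 256, 256]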

View File

@ -155,7 +155,7 @@ def check_openapi(check: bool):
filename, filename,
], ],
capture_output=True, capture_output=True,
).stdout.decode() ).stdout.decode("utf-8")
os.remove(tmp_filename) os.remove(tmp_filename)
if diff: if diff:
@ -164,10 +164,26 @@ def check_openapi(check: bool):
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
) )
return True
else: else:
os.rename(tmp_filename, filename) os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.") print("OpenAPI documentation updated.")
errors = subprocess.run(
[
"swagger-cli",
# `swagger-cli validate` reports schema problems on stderr
"validate",
filename,
],
capture_output=True,
).stderr.decode("utf-8")
# The OpenAPI spec fails validation on `exclusive_minimum`, which is expected to
# be a boolean where utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969
if errors and not errors.startswith("Swagger schema validation failed."):
print(errors)
raise Exception(
f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}"
)
return True return True