diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index 8af0b95d..e10b232c 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -30,6 +30,10 @@ jobs: id: install-router run: cargo install --path router/ + - uses: actions/setup-node@v4 + with: + node-version: 22 + - name: Set up Python uses: actions/setup-python@v2 with: @@ -37,4 +41,5 @@ jobs: - name: Check that documentation is up-to-date run: | + npm install -g swagger-cli python update_doc.py --check diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8213887f..6c968053 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -27,8 +27,8 @@ jobs: concurrency: group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true - # TODO see with @Glegendre to get CPU runner here instead - runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] + runs-on: + group: aws-r7i-8xlarge-priv permissions: contents: write packages: write @@ -49,7 +49,7 @@ jobs: export dockerfile="Dockerfile" export label_extension="" export docker_devices="" - export runs_on="nvidia-gpu" + export runs_on="aws-g5-12xlarge" ;; rocm) export dockerfile="Dockerfile_amd" @@ -79,9 +79,15 @@ jobs: uses: docker/setup-buildx-action@v3 with: install: true - config-inline: | + buildkitd-config-inline: | [registry."docker.io"] - mirrors = ["registry.github-runners.huggingface.tech"] + mirrors = ["registry-us-east-1-mirror.prod.aws.ci.huggingface.tech"] + - name: Login to internal Container Registry + uses: docker/login-action@v3 + with: + username: ${{ secrets.REGISTRY_USERNAME }} + password: ${{ secrets.REGISTRY_PASSWORD }} + registry: registry.internal.huggingface.tech - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 @@ -103,7 +109,8 @@ jobs: uses: docker/metadata-action@v5 with: images: | - registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference + registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference + registry.internal.huggingface.tech/api-inference/community/text-generation-inference tags: | type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} # If main, release or tag @@ -115,7 +122,8 @@ jobs: flavor: | latest=auto images: | - registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference + registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference + registry.internal.huggingface.tech/api-inference/community/text-generation-inferenceca ghcr.io/huggingface/text-generation-inference db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference tags: | @@ -136,12 +144,12 @@ jobs: DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} - cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min - cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min + cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ 
secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min + cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min - name: Final id: final run: | - echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" + echo "docker_image=registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT" echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT" echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT" @@ -150,7 +158,8 @@ jobs: group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true needs: build-and-push - runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] + runs-on: + group: ${{ needs.build-and-push.outputs.runs_on }} if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml index 637df472..ecfe0fda 100644 --- a/.github/workflows/load_test.yaml +++ b/.github/workflows/load_test.yaml @@ -15,7 +15,8 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true - runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] + runs-on: + group: aws-g5-12xlarge env: DOCKER_VOLUME: /cache steps: diff --git a/Cargo.lock b/Cargo.lock index 090e2e80..923b5cbe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -801,6 +801,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctrlc" version = "3.4.4" @@ -1935,17 +1956,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "metrics" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" -dependencies = [ - "ahash", - "metrics-macros", - "portable-atomic", -] - [[package]] name = "metrics" version = "0.23.0" @@ -1969,7 +1979,7 @@ dependencies = [ "hyper-util", "indexmap 2.2.6", "ipnet", - "metrics 0.23.0", + "metrics", "metrics-util", "quanta", "thiserror", @@ -1977,17 +1987,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "metrics-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", -] - [[package]] name = "metrics-util" version = "0.17.0" @@ -1997,7 +1996,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "metrics 0.23.0", + "metrics", "num_cpus", "quanta", "sketches-ddsketch", @@ -3424,9 +3423,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.118" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -3672,15 +3671,16 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "sysinfo" -version = "0.30.12" +version = "0.30.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", + "rayon", "windows", ] @@ -3762,7 +3762,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "average", "clap", @@ -3783,7 +3783,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -3801,7 +3801,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "clap", "ctrlc", @@ -3820,13 +3820,14 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "async-stream", "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", "clap", + "csv", "futures", "futures-util", "hf-hub", @@ -3834,7 +3835,7 @@ dependencies = [ "init-tracing-opentelemetry", "itertools 0.10.5", "jsonschema", - "metrics 0.21.1", + "metrics", "metrics-exporter-prometheus", "minijinja", "minijinja-contrib", @@ -3848,6 +3849,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "sysinfo", "text-generation-client", "thiserror", "tokenizers", @@ -3859,6 +3861,7 @@ dependencies = [ "tracing-subscriber", "utoipa", "utoipa-swagger-ui", + "uuid", "vergen", ] @@ -4530,9 +4533,25 @@ dependencies = [ [[package]] name = "uuid" -version = "1.9.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", + "rand", + "uuid-macro-internal", +] + +[[package]] +name = "uuid-macro-internal" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] [[package]] name = "v_frame" diff --git a/Dockerfile b/Dockerfile index d4772b4a..54ddd5ef 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,9 @@ RUN cargo build --profile release-opt # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install +# NOTE: When updating PyTorch version, beware to remove `pip 
install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 ARG PYTORCH_VERSION=2.3.0 + ARG PYTHON_VERSION=3.10 # Keep in sync with `server/pyproject.toml ARG CUDA_VERSION=12.1 @@ -159,6 +161,17 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build FBGEMM CUDA kernels +FROM kernel-builder AS fbgemm-builder + +WORKDIR /usr/src + +COPY server/Makefile-fbgemm Makefile +COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch +COPY server/fix_torch90a.sh fix_torch90a.sh + +RUN make build-fbgemm + # Build vllm CUDA kernels FROM kernel-builder AS vllm-builder @@ -223,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 # Copy build artifacts from marlin kernels builder COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - -# Copy builds artifacts from vllm builder +# Copy build artifacts from fbgemm builder +COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages @@ -241,7 +254,10 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements_cuda.txt && \ - pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir + pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \ + pip install nvidia-nccl-cu12==2.22.3 + +ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 # Deps before the binaries # The binaries change on every build given we burn the SHA into them diff --git a/README.md b/README.md index 4c1c1e29..4287c119 100644 --- a/README.md +++ b/README.md @@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint. 
## Table of contents -- [Get Started](#get-started) - - [API Documentation](#api-documentation) - - [Using a private or gated model](#using-a-private-or-gated-model) - - [A note on Shared Memory](#a-note-on-shared-memory-shm) - - [Distributed Tracing](#distributed-tracing) - - [Local Install](#local-install) - - [CUDA Kernels](#cuda-kernels) -- [Optimized architectures](#optimized-architectures) -- [Run Mistral](#run-a-model) - - [Run](#run) - - [Quantization](#quantization) -- [Develop](#develop) -- [Testing](#testing) + - [Get Started](#get-started) + - [Docker](#docker) + - [API documentation](#api-documentation) + - [Using a private or gated model](#using-a-private-or-gated-model) + - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm) + - [Distributed Tracing](#distributed-tracing) + - [Architecture](#architecture) + - [Local install](#local-install) + - [Optimized architectures](#optimized-architectures) + - [Run locally](#run-locally) + - [Run](#run) + - [Quantization](#quantization) + - [Develop](#develop) + - [Testing](#testing) Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as: diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index a56edaca..e36dd470 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): role: str content: Optional[str] = None - tool_calls: Optional[ChoiceDeltaToolCall] + tool_calls: Optional[ChoiceDeltaToolCall] = None class Choice(BaseModel): diff --git a/docs/openapi.json b/docs/openapi.json index 9c9a8b1a..7000c7b7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -492,12 +492,12 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/Completion" + "$ref": "#/components/schemas/CompletionFinal" } }, "text/event-stream": { "schema": { - "$ref": "#/components/schemas/CompletionCompleteChunk" + "$ref": "#/components/schemas/Chunk" } } } @@ -809,7 +809,6 @@ "ChatRequest": { "type": "object", "required": [ - "model", "messages" ], "properties": { @@ -854,7 +853,8 @@ "model": { "type": "string", "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "n": { "type": "integer", @@ -909,7 +909,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -1116,7 +1116,6 @@ "CompletionRequest": { "type": "object", "required": [ - "model", "prompt" ], "properties": { @@ -1138,7 +1137,8 @@ "model": { "type": "string", "description": "UNUSED\nID of the model to use. 
See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "prompt": { "$ref": "#/components/schemas/Prompt" @@ -1324,6 +1324,17 @@ } } }, + "FunctionName": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string" + } + } + }, "GenerateParameters": { "type": "object", "properties": { @@ -1708,6 +1719,72 @@ } } }, + "MessageChunk": { + "oneOf": [ + { + "type": "object", + "required": [ + "text", + "type" + ], + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "enum": [ + "text" + ] + } + } + }, + { + "type": "object", + "required": [ + "image_url", + "type" + ], + "properties": { + "image_url": { + "$ref": "#/components/schemas/Url" + }, + "type": { + "type": "string", + "enum": [ + "image_url" + ] + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, + "MessageContent": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageChunk" + } + } + ] + }, + "OutputMessage": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + } + ] + }, "PrefillToken": { "type": "object", "required": [ @@ -1834,6 +1911,23 @@ } } }, + "TextMessage": { + "type": "object", + "required": [ + "role", + "content" + ], + "properties": { + "content": { + "type": "string", + "example": "My name is David and I" + }, + "role": { + "type": "string", + "example": "user" + } + } + }, "Token": { "type": "object", "required": [ @@ -1906,6 +2000,49 @@ } } }, + "ToolCallDelta": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "$ref": "#/components/schemas/DeltaToolCall" + } + } + }, + "ToolCallMessage": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + } + } + } + }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ { @@ -1926,9 +2063,25 @@ "$ref": "#/components/schemas/FunctionName" } } + }, + { + "type": "object", + "default": null, + "nullable": true } ] }, + "Url": { + "type": "object", + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + } + } + }, "Usage": { "type": "object", "required": [ diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index c9b4efd9..e97c00aa 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -11,6 +11,8 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_intel + title: Using TGI with Intel GPUs - local: installation title: Installation from source - local: supported_models @@ -19,6 +21,8 @@ title: Messages API - local: architecture title: Internal Architecture + - local: usage_statistics + title: Usage Statistics title: Getting started - sections: - local: basic_tutorials/consuming_tgi diff --git a/docs/source/architecture.md b/docs/source/architecture.md index a8418817..28c84f62 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -103,6 +103,7 @@ Several variants 
of the model server exist that are actively supported by Huggin - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ. - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi). - A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f..77f88490 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -424,6 +424,22 @@ Options: [env: LORA_ADAPTERS=] +``` +## DISABLE_USAGE_STATS +```shell + --disable-usage-stats + Disable sending of all usage statistics + + [env: DISABLE_USAGE_STATS=] + +``` +## DISABLE_CRASH_REPORTS +```shell + --disable-crash-reports + Disable sending of crash reports, but allow anonymous usage statistics + + [env: DISABLE_CRASH_REPORTS=] + ``` ## HELP ```shell diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md new file mode 100644 index 00000000..f9fda863 --- /dev/null +++ b/docs/source/installation_intel.md @@ -0,0 +1,19 @@ +# Using TGI with Intel GPUs + +TGI-optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html) and [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html); the recommended usage is through Docker. + + +On a server powered by Intel GPUs, TGI can be launched with the following command: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --rm --privileged --cap-add=sys_nice \ + --device=/dev/dri \ + --ipc=host --shm-size 1g --net host -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:latest-intel \ + --model-id $model --cuda-graphs 0 +``` + +The launched TGI server can then be queried from clients; make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide. diff --git a/docs/source/quicktour.md index c546bc03..f056baad 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ ### Supported hardware -TGI supports various hardware.
Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. +TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. ## Consuming TGI diff --git a/docs/source/supported_models.md index 2bdd00de..bc124f31 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware ## Supported Models +- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md new file mode 100644 index 00000000..adf0d70f --- /dev/null +++ b/docs/source/usage_statistics.md @@ -0,0 +1,73 @@ + +# Collection of Usage Statistics + +Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted. + +Data is sent twice: once on server startup and once when the server stops. Also, usage statistics are only enabled when TGI is running in Docker, to avoid collecting data when TGI runs directly on the host machine. + +## What data is collected + +The code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs).
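As a point of comparison, most of the GPU fields reported under `nvidia_env` in the example below can be reproduced locally with an `nvidia-smi` query along these lines (an illustrative sketch only; it is not necessarily how `usage_stats.rs` gathers these values):

```bash
# Illustrative only: approximate the `nvidia_env` fields of the usage-statistics
# payload with an nvidia-smi query. The real collection logic lives in
# router/src/usage_stats.rs and may obtain these values differently.
nvidia-smi \
  --query-gpu=name,pci.bus_id,driver_version,pstate,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,compute_cap \
  --format=csv
```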
+As of release 2.1.2, this is an example of the data collected: + +- From the TGI configuration: +```json +{ + "event_type": "start", + "disable_grammar_support": false, + "max_batch_prefill_tokens": 4096, + "max_batch_size": null, + "max_batch_total_tokens": null, + "max_best_of": 2, + "max_client_batch_size": 4, + "max_concurrent_requests": 128, + "max_input_tokens": 1024, + "max_stop_sequences": 4, + "max_top_n_tokens": 5, + "max_total_tokens": 2048, + "max_waiting_tokens": 20, + "messages_api_enabled": false, + "model_config": { + "model_type": "Bloom" + }, + "revision": null, + "tokenizer_class": "BloomTokenizerFast", + "validation_workers": 2, + "waiting_served_ratio": 1.2, + "docker_label": "latest", + "git_sha": "cfc118704880453d29bcbe4fbbd91dda501cf5fe", + "nvidia_env": { + "name": "NVIDIA A10G", + "pci_bus_id": "00000000:00:1E.0", + "driver_version": "535.183.01", + "pstate": "P8", + "pcie_link_gen_max": "4", + "pcie_link_gen_current": "1", + "temperature_gpu": "31", + "utilization_gpu": "0 %", + "utilization_memory": "0 %", + "memory_total": "23028 MiB", + "memory_free": "22515 MiB", + "memory_used": "0 MiB", + "reset_status_reset_required": "No", + "reset_status_drain_and_reset_recommended": "No", + "compute_cap": "8.6", + "ecc_errors_corrected_volatile_total": "0", + "mig_mode_current": "[N/A]", + "power_draw_instant": "10.86 W", + "power_limit": "300.00 W" + }, + "system_env": { + "cpu_count": 16, + "cpu_type": "AMD EPYC 7R32", + "total_memory": 66681196544, + "architecture": "x86_64", + "platform": "linux-unix-x86_64" + } +} + +``` + +## How to opt-out + +You can easily opt out by passing the `--disable-usage-stats` flag to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports`, which disables the sending of crash reports but still allows anonymous usage statistics.
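For example, starting from the quicktour's Docker invocation, the flags are simply appended to the launcher arguments (an illustrative sketch; the image tag and model below are placeholders, adjust them to your deployment):

```bash
# Illustrative sketch: launch TGI via Docker with all usage statistics disabled.
# The model and image tag are placeholders.
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id $model \
    --disable-usage-stats

# Alternatively, keep anonymous usage statistics but disable crash reports:
#   ... --model-id $model --disable-crash-reports
```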
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index f5f38ac6..60146ad1 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -333,6 +333,8 @@ def launcher(event_loop): max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, + lora_adapters: Optional[List[str]] = None, + cuda_graphs: Optional[List[int]] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -379,6 +381,14 @@ def launcher(event_loop): if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) + if lora_adapters: + args.append("--lora-adapters") + args.append(",".join(lora_adapters)) + if cuda_graphs: + args.append("--cuda-graphs") + args.append(",".join(map(str, cuda_graphs))) + + print(" ".join(args), file=sys.stderr) env["LOG_LEVEL"] = "info,text_generation_router=debug" @@ -418,6 +428,8 @@ def launcher(event_loop): max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, + lora_adapters: Optional[List[str]] = None, + cuda_graphs: Optional[List[int]] = None, ): port = random.randint(8000, 10_000) @@ -447,6 +459,12 @@ def launcher(event_loop): if max_total_tokens: args.append("--max-total-tokens") args.append(str(max_total_tokens)) + if lora_adapters: + args.append("--lora-adapters") + args.append(",".join(lora_adapters)) + if cuda_graphs: + args.append("--cuda-graphs") + args.append(",".join(map(str, cuda_graphs))) client = docker.from_env() diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json new file mode 100644 index 00000000..03f90367 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.84375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.34375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.8359375, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.0859375, + "special": false, + "text": " is" + }, + { + "id": 254, + "logprob": -1.5390625, + "special": false, + "text": " the" + }, + { + "id": 1022, + "logprob": -1.1875, + "special": false, + "text": " first" + }, + { + "id": 3458, + "logprob": -0.35546875, + "special": false, + "text": " step" + }, + { + "id": 279, + "logprob": -0.8828125, + "special": false, + "text": " in" + }, + { + "id": 254, + "logprob": -0.71484375, + "special": false, + "text": " the" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is the first step in the" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json new file mode 100644 index 00000000..e84135cf --- /dev/null +++ 
b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2143, + "logprob": -1.828125, + "special": false, + "text": " sent" + }, + { + "id": 10081, + "logprob": -0.36914062, + "special": false, + "text": " successfully" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 185, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 1380, + "logprob": -0.38671875, + "special": false, + "text": "We" + }, + { + "id": 543, + "logprob": -0.12695312, + "special": false, + "text": " will" + }, + { + "id": 752, + "logprob": -0.20117188, + "special": false, + "text": " get" + }, + { + "id": 279, + "logprob": 0.0, + "special": false, + "text": " in" + }, + { + "id": 5402, + "logprob": 0.0, + "special": false, + "text": " touch" + }, + { + "id": 366, + "logprob": 0.0, + "special": false, + "text": " with" + } + ], + "top_tokens": null + }, + "generated_text": "Test request sent successfully.\nWe will get in touch with" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json new file mode 100644 index 00000000..a4b9784a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": 
-2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json 
b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json new file mode 100644 index 00000000..85cfb91f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7900391, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3554688, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0039062, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.4489746, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8100586, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013015747, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json new file mode 100644 index 00000000..dcb4d063 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 25, + "logprob": -0.8535156, + "special": false, + "text": ":" + }, + { + "id": 2209, + "logprob": -2.4804688, + "special": false, + "text": " Is" + }, + { + "id": 279, + "logprob": -0.7167969, + "special": false, + "text": " the" + }, + { + "id": 734, + "logprob": -2.625, + "special": false, + "text": " function" + }, + { + "id": 330, + "logprob": -0.35131836, + "special": false, + "text": " \"" + }, + { + "id": 4110, + "logprob": -2.4101562, + "special": false, + "text": "Create" + }, + { + "id": 264, + "logprob": -0.23181152, + "special": false, + "text": " a" + }, + { + "id": 502, + "logprob": -0.25512695, + "special": false, + "text": " new" + }, + { + "id": 1052, + "logprob": -1.2792969, + "special": false, + "text": " file" + }, + { + "id": 1, + "logprob": -1.2529297, + "special": false, + "text": "\"" + } + ], + "top_tokens": null + }, + "generated_text": "Test request: Is the function \"Create a new file\"" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json new file mode 100644 index 00000000..36c87c09 --- /dev/null +++ 
b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + 
"logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json new file mode 100644 index 00000000..94883de5 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json 
b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json new file mode 100644 index 00000000..58cacb80 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -0.6645508, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 6527, + "logprob": -2.2324219, + "special": false, + "text": " Could" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " not" + }, + { + "id": 6088, + "logprob": -1.6074219, + "special": false, + "text": " parse" + }, + { + "id": 1243, + "logprob": -1.6298828, + "special": false, + "text": " test" + }, + { + "id": 1206, + "logprob": -0.72558594, + "special": false, + "text": " case" + }, + { + "id": 1024, + "logprob": -0.40429688, + "special": false, + "text": " name" + }, + { + "id": 515, + "logprob": 0.0, + "special": false, + "text": " from" + }, + { + "id": 525, + "logprob": -1.2519531, + "special": false, + "text": " '" + } + ], + "top_tokens": null + }, + "generated_text": "Test request failed: Could not parse test case name from '" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json new file mode 100644 index 00000000..96a40fa4 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, 
+ "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + 
"text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + } +] diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json new file mode 100644 index 00000000..dfdd2cc3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.27416992, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.17016602, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -2.7109375, + "special": false, + "text": "I" + }, + { + "id": 28809, + "logprob": -1.5, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.34204102, + "special": false, + "text": "m" + }, + { + "id": 459, + "logprob": -1.6914062, + "special": false, + "text": " not" + }, + { + "id": 1864, + "logprob": -0.69140625, + "special": false, + "text": " sure" + }, + { + "id": 513, + "logprob": -1.6171875, + "special": false, + "text": " if" + }, + { + "id": 315, + "logprob": -1.3837891, + "special": false, + "text": " I" + }, + { + "id": 541, + "logprob": -1.2226562, + "special": false, + "text": " can" + }, + { + "id": 1567, + "logprob": -1.8652344, + "special": false, + "text": " come" + }, + { + "id": 582, + "logprob": -0.0070228577, + "special": false, + "text": " up" + }, + { + "id": 395, + "logprob": -0.0054092407, + "special": false, + "text": " with" + }, + { + "id": 28705, + "logprob": -0.62597656, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.0035572052, + "special": false, + "text": "3" + }, + { + "id": 4842, + "logprob": -0.93603516, + "special": false, + "text": " unique" + }, + { + "id": 3085, + "logprob": -0.028411865, + "special": false, + "text": " words" + }, + { + "id": 369, + "logprob": -1.0400391, + "special": false, + "text": " that" + }, + { + "id": 6685, + "logprob": -0.09710693, + "special": false, + "text": " describe" + }, + { + "id": 528, + "logprob": -0.066467285, + "special": false, + "text": " me" + }, + { + "id": 28725, + "logprob": -1.0722656, + "special": false, + "text": "," + }, + { + "id": 562, + "logprob": -0.33422852, + "special": false, + "text": " but" + }, + { + "id": 315, + "logprob": -0.5136719, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -0.8989258, + "special": false, + "text": "’" + }, + { + "id": 584, + "logprob": -0.2076416, + "special": false, + "text": "ll" + }, + { + "id": 1464, + "logprob": -0.8808594, + "special": false, + "text": " try" + }, + { + "id": 28723, + "logprob": -0.88427734, + "special": false, + "text": "." + }, + { + "id": 13, + "logprob": -0.91064453, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.08105469, + "special": false, + "text": "\n" + }, + { + "id": 28740, + "logprob": -1.8486328, + "special": false, + "text": "1" + }, + { + "id": 28723, + "logprob": -0.111572266, + "special": false, + "text": "." 
+ }, + { + "id": 23626, + "logprob": -3.15625, + "special": false, + "text": " Creative" + }, + { + "id": 13, + "logprob": -0.9194336, + "special": false, + "text": "\n" + }, + { + "id": 28750, + "logprob": -0.24841309, + "special": false, + "text": "2" + }, + { + "id": 28723, + "logprob": -9.393692e-05, + "special": false, + "text": "." + }, + { + "id": 6785, + "logprob": -3.1386719, + "special": false, + "text": " Fun" + }, + { + "id": 1780, + "logprob": -0.53564453, + "special": false, + "text": "ny" + }, + { + "id": 13, + "logprob": -0.09033203, + "special": false, + "text": "\n" + }, + { + "id": 28770, + "logprob": -0.00466156, + "special": false, + "text": "3" + }, + { + "id": 28723, + "logprob": -0.00016450882, + "special": false, + "text": "." + } + ] + }, + "generated_text": "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3." +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json new file mode 100644 index 00000000..91eb5edf --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json @@ -0,0 +1,53 @@ +{ + "details": { + "finish_reason": "eos_token", + "generated_tokens": 7, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 1, + "logprob": -0.49658203, + "special": true, + "text": "" + }, + { + "id": 28705, + "logprob": -0.0016384125, + "special": false, + "text": " " + }, + { + "id": 1, + "logprob": -1.4931641, + "special": true, + "text": "" + }, + { + "id": 28705, + "logprob": -0.00075769424, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -0.25024414, + "special": false, + "text": "1" + }, + { + "id": 28740, + "logprob": -0.2631836, + "special": false, + "text": "1" + }, + { + "id": 2, + "logprob": -0.0003285408, + "special": true, + "text": "" + } + ] + }, + "generated_text": " 11" +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json new file mode 100644 index 00000000..13018688 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.0488281, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.0800781, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -2.1152344, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -1.6748047, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -0.097229004, + "special": false, + "text": "1" + }, + { + "id": 28723, + "logprob": -0.16467285, + "special": false, + "text": "." 
+ }, + { + "id": 7615, + "logprob": -2.2246094, + "special": false, + "text": " News" + }, + { + "id": 13, + "logprob": -1.0488281, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.69189453, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.013343811, + "special": false, + "text": " " + }, + { + "id": 28750, + "logprob": -0.011230469, + "special": false, + "text": "2" + }, + { + "id": 28723, + "logprob": -0.00096845627, + "special": false, + "text": "." + }, + { + "id": 21095, + "logprob": -2.5605469, + "special": false, + "text": " Blog" + }, + { + "id": 13, + "logprob": -0.19458008, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.031280518, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.0030708313, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.0029277802, + "special": false, + "text": "3" + }, + { + "id": 28723, + "logprob": -0.0012350082, + "special": false, + "text": "." + }, + { + "id": 20108, + "logprob": -2.1582031, + "special": false, + "text": " Article" + }, + { + "id": 13, + "logprob": -0.05810547, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.35083008, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.034332275, + "special": false, + "text": " " + }, + { + "id": 28781, + "logprob": -0.009666443, + "special": false, + "text": "4" + }, + { + "id": 28723, + "logprob": -0.0013113022, + "special": false, + "text": "." + }, + { + "id": 8349, + "logprob": -2.6191406, + "special": false, + "text": " Review" + }, + { + "id": 13, + "logprob": -0.04031372, + "special": false, + "text": "\n" + }, + { + "id": 27332, + "logprob": -0.45239258, + "special": false, + "text": "###" + }, + { + "id": 28705, + "logprob": -0.045410156, + "special": false, + "text": " " + }, + { + "id": 28782, + "logprob": -0.0041236877, + "special": false, + "text": "5" + }, + { + "id": 28723, + "logprob": -0.0010223389, + "special": false, + "text": "." + }, + { + "id": 5299, + "logprob": -2.8066406, + "special": false, + "text": " Other" + }, + { + "id": 13, + "logprob": -0.12054443, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.44580078, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.4921875, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.3574219, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.0039062, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.5859375, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.43481445, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.2783203, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.20410156, + "special": false, + "text": "\n" + } + ] + }, + "generated_text": "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. 
Other\n\n\n\n\n\n\n\n\n" +} diff --git a/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json new file mode 100644 index 00000000..8c00dee7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json @@ -0,0 +1,251 @@ +{ + "details": { + "finish_reason": "length", + "generated_tokens": 40, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.31347656, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.27441406, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -2.2285156, + "special": false, + "text": "I" + }, + { + "id": 28809, + "logprob": -1.4677734, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.31762695, + "special": false, + "text": "m" + }, + { + "id": 264, + "logprob": -1.6865234, + "special": false, + "text": " a" + }, + { + "id": 1215, + "logprob": -3.2695312, + "special": false, + "text": " very" + }, + { + "id": 20640, + "logprob": -3.1230469, + "special": false, + "text": " passionate" + }, + { + "id": 1338, + "logprob": -0.48339844, + "special": false, + "text": " person" + }, + { + "id": 28723, + "logprob": -0.9970703, + "special": false, + "text": "." + }, + { + "id": 315, + "logprob": -0.5498047, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -1.1923828, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.080444336, + "special": false, + "text": "m" + }, + { + "id": 1215, + "logprob": -1.8271484, + "special": false, + "text": " very" + }, + { + "id": 12215, + "logprob": -2.8847656, + "special": false, + "text": " driven" + }, + { + "id": 28723, + "logprob": -1.0927734, + "special": false, + "text": "." + }, + { + "id": 315, + "logprob": -0.4584961, + "special": false, + "text": " I" + }, + { + "id": 28809, + "logprob": -0.5019531, + "special": false, + "text": "’" + }, + { + "id": 28719, + "logprob": -0.030715942, + "special": false, + "text": "m" + }, + { + "id": 1215, + "logprob": -0.96972656, + "special": false, + "text": " very" + }, + { + "id": 7798, + "logprob": -2.8847656, + "special": false, + "text": " determined" + }, + { + "id": 28723, + "logprob": -0.27319336, + "special": false, + "text": "." + }, + { + "id": 13, + "logprob": -0.56396484, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.011016846, + "special": false, + "text": "\n" + }, + { + "id": 3195, + "logprob": -0.7163086, + "special": false, + "text": "What" + }, + { + "id": 349, + "logprob": -1.1611328, + "special": false, + "text": " is" + }, + { + "id": 574, + "logprob": -0.515625, + "special": false, + "text": " your" + }, + { + "id": 6656, + "logprob": -1.0253906, + "special": false, + "text": " favorite" + }, + { + "id": 1970, + "logprob": -2.1738281, + "special": false, + "text": " thing" + }, + { + "id": 684, + "logprob": -0.48364258, + "special": false, + "text": " about" + }, + { + "id": 1250, + "logprob": -1.8876953, + "special": false, + "text": " being" + }, + { + "id": 264, + "logprob": -0.41967773, + "special": false, + "text": " a" + }, + { + "id": 8626, + "logprob": -2.9160156, + "special": false, + "text": " teacher" + }, + { + "id": 28804, + "logprob": -0.11920166, + "special": false, + "text": "?" 
+ }, + { + "id": 13, + "logprob": -0.023727417, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.010848999, + "special": false, + "text": "\n" + }, + { + "id": 28737, + "logprob": -1.0566406, + "special": false, + "text": "I" + }, + { + "id": 2016, + "logprob": -0.7163086, + "special": false, + "text": " love" + }, + { + "id": 272, + "logprob": -1.9169922, + "special": false, + "text": " the" + }, + { + "id": 1639, + "logprob": -2.03125, + "special": false, + "text": " fact" + } + ] + }, + "generated_text": "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact" +} diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index 0efb6693..d787873b 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream( chunk = [c.replace("data:", "") for c in chunk] # remove empty strings chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] # parse json chunk = [json.loads(c) for c in chunk] diff --git a/integration-tests/models/test_flash_deepseek_v2.py b/integration-tests/models/test_flash_deepseek_v2.py new file mode 100644 index 00000000..010e08c9 --- /dev/null +++ b/integration-tests/models/test_flash_deepseek_v2.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_deepseek_v2_handle(launcher): + with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_deepseek_v2(flash_deepseek_v2_handle): + await flash_deepseek_v2_handle.health(300) + return flash_deepseek_v2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_load( + flash_deepseek_v2, generate_load, response_snapshot +): + responses = await generate_load( + flash_deepseek_v2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_llama_fp8.py b/integration-tests/models/test_flash_llama_fp8.py new file mode 100644 index 00000000..fe5df590 --- /dev/null +++ b/integration-tests/models/test_flash_llama_fp8.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_fp8_handle(launcher): + with launcher("meta-llama/Meta-Llama-3-8B", num_shard=2, quantize="fp8") as handle: + yield handle + + 
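Editor's note: the test_completion_prompts.py change in the hunk above filters out the new completion-marking chunk before JSON-parsing the stream, because the router now terminates its SSE streams with a final "data: [DONE]" event (see the server.rs hunks further down). A minimal client-side sketch of the same skip-the-sentinel logic, assuming a TGI instance reachable at a hypothetical http://localhost:8080:

import json
import requests

# Minimal sketch: consume a streamed completion and stop at the "[DONE]" sentinel.
# The host/port and prompt are assumptions; the skip logic mirrors the test change above.
with requests.post(
    "http://localhost:8080/v1/completions",
    json={"model": "tgi", "prompt": "What is Deep Learning?", "max_tokens": 10, "stream": True},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip keep-alives and empty lines
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break  # completion-marking chunk, not JSON
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)

Clients that blindly json.loads every data: payload will now fail on the final event, which is exactly why the test strips it first.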
+@pytest.fixture(scope="module") +async def flash_llama_fp8(flash_llama_fp8_handle): + await flash_llama_fp8_handle.health(300) + return flash_llama_fp8_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): + responses = await generate_load( + flash_llama_fp8, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_llama_marlin_24.py b/integration-tests/models/test_flash_llama_marlin_24.py new file mode 100644 index 00000000..3eb94f02 --- /dev/null +++ b/integration-tests/models/test_flash_llama_marlin_24.py @@ -0,0 +1,66 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_marlin24_handle(launcher): + with launcher( + "nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_marlin(flash_llama_marlin24_handle): + await flash_llama_marlin24_handle.health(300) + return flash_llama_marlin24_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_load( + flash_llama_marlin, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_marlin, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_lora_mistral.py b/integration-tests/models/test_lora_mistral.py new file mode 100644 index 00000000..ccdc1486 --- /dev/null +++ 
b/integration-tests/models/test_lora_mistral.py @@ -0,0 +1,134 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def lora_mistral_handle(launcher): + with launcher( + "mistralai/Mistral-7B-v0.1", + lora_adapters=[ + "predibase/dbpedia", + "predibase/customer_support", + ], + cuda_graphs=[0], + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def lora_mistral(lora_mistral_handle): + await lora_mistral_handle.health(300) + return lora_mistral_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral(lora_mistral, response_snapshot): + response = await lora_mistral.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + assert response.details.generated_tokens == 10 + + +classification_prompt = """You are given the title and the body of an article below. Please determine the type of the article.\n### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\n\n### Article Type:""" + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_without_adapter(lora_mistral, response_snapshot): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": classification_prompt, + "parameters": { + "max_new_tokens": 40, + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n" + ) + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": classification_prompt, + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/dbpedia", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["generated_text"] == " 11" + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_with_customer_support_adapter( + lora_mistral, response_snapshot +): + print(lora_mistral.base_url) + print(lora_mistral.headers) + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": "What are 3 unique words that describe you?", + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/customer_support", + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3." 
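Editor's note: the test_lora_mistral.py tests above exercise per-request LoRA routing: the launcher preloads the predibase/dbpedia and predibase/customer_support adapters via lora_adapters, and each request can opt into one of them through parameters.adapter_id. A standalone sketch of the same call outside the pytest harness; the base URL and headers are assumptions, the adapter name is taken from the test:

import requests

# Minimal sketch of per-request LoRA selection via `adapter_id`, assuming a server
# started with the predibase/dbpedia and predibase/customer_support adapters loaded
# and reachable at this (hypothetical) base URL.
BASE_URL = "http://localhost:8080"

response = requests.post(
    f"{BASE_URL}/generate",
    headers={"Content-Type": "application/json"},
    json={
        "inputs": "What are 3 unique words that describe you?",
        "parameters": {
            "max_new_tokens": 40,
            "adapter_id": "predibase/customer_support",  # omit to use the base model
            "details": True,
        },
    },
)
response.raise_for_status()
print(response.json()["generated_text"])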
+ ) + assert data == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_lora_mistral_without_customer_support_adapter( + lora_mistral, response_snapshot +): + response = requests.post( + f"{lora_mistral.base_url}/generate", + headers=lora_mistral.headers, + json={ + "inputs": "What are 3 unique words that describe you?", + "parameters": { + "max_new_tokens": 40, + "details": True, + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert ( + data["generated_text"] + == "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact" + ) + assert data == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d2ca38e5..ef7b4712 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -457,6 +457,14 @@ struct Args { /// startup that will be available to callers via the `adapter_id` field in a request. #[clap(long, env)] lora_adapters: Option, + + /// Disable sending of all usage statistics + #[clap(default_value = "false", long, env)] + disable_usage_stats: bool, + + /// Disable sending of crash reports, but allow anonymous usage statistics + #[clap(default_value = "false", long, env)] + disable_crash_reports: bool, } #[derive(Debug)] @@ -1201,6 +1209,14 @@ fn spawn_webserver( args.model_id, ]; + // Pass usage stats flags to router + if args.disable_usage_stats { + router_args.push("--disable-usage-stats".to_string()); + } + if args.disable_crash_reports { + router_args.push("--disable-crash-reports".to_string()); + } + // Grammar support if args.disable_grammar_support { router_args.push("--disable-grammar-support".to_string()); diff --git a/router/Cargo.toml b/router/Cargo.toml index 5855ac86..0fc700a0 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -24,7 +24,7 @@ futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.21.1" +metrics = "0.23.0" metrics-exporter-prometheus = { version = "0.15.1", features = [] } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } @@ -52,6 +52,10 @@ regex = "1.10.3" once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } +sysinfo = "0.30.13" +uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +csv = "1.3.0" + [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 49282eb9..db9070d4 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, }; use crate::{ FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, @@ -91,14 +91,14 @@ impl Infer { .limit_concurrent_requests .try_acquire_owned() .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); + metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1); tracing::error!("{err}"); err })?; // Validate request let valid_request = 
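Editor's note: the launcher/src/main.rs hunk above adds --disable-usage-stats and --disable-crash-reports (both also settable through the environment via clap's env attribute) and forwards them to the router, which in turn only reports when it detects it is running in a container. A hedged sketch of scripting a local launch with both flags set; the launcher binary name is real, but its presence on PATH and the model id are assumptions:

import subprocess

# Minimal sketch: start TGI with usage statistics and crash reports disabled.
# Runs the server in the foreground until interrupted; the flags come from the
# launcher args added in this diff, the model id is only an example.
subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "mistralai/Mistral-7B-v0.1",
        "--disable-usage-stats",
        "--disable-crash-reports",
    ],
    check=True,
)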
self.validation.validate(request).await.map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); err })?; @@ -140,7 +140,7 @@ impl Infer { .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .apply(messages, grammar_with_prompt) .map_err(|e| { - metrics::increment_counter!("tgi_request_failure", "err" => "template"); + metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); e }) @@ -214,7 +214,7 @@ impl Infer { }) } else { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); Err(err) } @@ -332,126 +332,131 @@ impl ChatTemplate { pub struct ToolGrammar {} impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + pub fn apply( tools: Option>, - tool_choice: Option, + tool_choice: ToolChoice, ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::Function { function } => { - let tool = req_tools - .iter() - .find(|tool| tool.function.name == function.name) - .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) - .clone(); - vec![tool] - } - ToolType::OneOf => req_tools.to_owned(), - }; + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); + // if tools are provided and no tool_choice we default to the OneOf + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] 
+ } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" }), - )]) - .collect(); + ); - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; + // Check if 'required' exists, and it is an array. If not, create an empty array. 
+ let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) } } diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 93cf9469..0b51645a 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -111,7 +111,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -124,7 +124,7 @@ async fn queue_task( let next_batch = state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); }), } } @@ -226,7 +226,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -336,7 +336,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index e4c3de26..97379bc5 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -148,8 +148,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -170,9 +170,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + 
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -219,8 +221,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -234,7 +236,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -248,11 +250,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -261,7 +267,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -276,7 +282,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -291,13 +297,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", 
start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -307,7 +318,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -381,7 +392,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -407,7 +418,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. 
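Editor's note: the scheduler and queue hunks above (and the v3 copies that follow) migrate the router from the metrics 0.21 macros to the 0.23 handle-based API — counter!(...).increment(1), gauge!(...).set(v), histogram!(...).record(v) — without renaming any series. One way to confirm the counters and histograms still come out unchanged is to scrape the router's Prometheus exposition; a small sketch, assuming the usual /metrics route on the serving port (host and port are assumptions):

import requests

# Minimal sketch: scrape the Prometheus text exposition from a running router and
# print the TGI series touched by this migration. Host/port are assumptions.
METRICS_URL = "http://localhost:8080/metrics"

wanted = (
    "tgi_request_success",
    "tgi_request_failure",
    "tgi_queue_size",
    "tgi_batch_next_size",
    "tgi_batch_inference_duration",
)

body = requests.get(METRICS_URL, timeout=5).text
for line in body.splitlines():
    if line.startswith(wanted):
        print(line)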
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index ba65b9b6..894d9cab 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -126,7 +126,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -141,7 +141,7 @@ async fn queue_task( .instrument(span) .await; response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); } } } @@ -248,7 +248,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -399,7 +399,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 543ce89f..26cd9584 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -154,8 +154,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -176,9 +176,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -225,8 +227,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -240,7 +242,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -254,11 +256,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - 
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -267,7 +273,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -282,7 +288,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -297,13 +303,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -313,7 +324,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + 
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -387,7 +398,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -413,7 +424,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. diff --git a/router/src/lib.rs b/router/src/lib.rs index 165b2ad2..b6e0d09d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -7,6 +7,8 @@ mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod usage_stats; + use serde::{Deserialize, Serialize}; use tracing::warn; use utoipa::ToSchema; @@ -40,13 +42,13 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatTemplate { name: String, template: String, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] pub enum ChatTemplateVersions { Single(String), @@ -55,7 +57,7 @@ pub enum ChatTemplateVersions { use std::path::Path; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, @@ -384,7 +386,7 @@ pub struct CompletionRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] @@ -731,7 +733,7 @@ impl ChatCompletionChunk { pub(crate) struct ChatRequest { #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// A list of messages comprising the conversation so far. #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] @@ -824,7 +826,7 @@ pub(crate) struct ChatRequest { /// A specific tool to use. 
If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - pub tool_choice: Option, + pub tool_choice: ToolChoice, /// Response format constraints for the generation. /// @@ -846,34 +848,34 @@ pub enum ToolType { OneOf, FunctionName(String), Function { function: FunctionName }, + NoTool, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[serde(from = "ToolTypeDeserializer")] pub struct ToolChoice(pub Option); #[derive(Deserialize)] #[serde(untagged)] enum ToolTypeDeserializer { - None(Option), - Some(ToolType), + String(String), + ToolType(ToolType), } impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::None(opt) => match opt.as_deref() { - Some("none") => ToolChoice(None), - Some("auto") => ToolChoice(Some(ToolType::OneOf)), - Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), - None => ToolChoice(Some(ToolType::OneOf)), + ToolTypeDeserializer::String(s) => match s.as_str() { + "none" => ToolChoice(Some(ToolType::NoTool)), + "auto" => ToolChoice(Some(ToolType::OneOf)), + _ => ToolChoice(Some(ToolType::FunctionName(s))), }, - ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), + ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } } diff --git a/router/src/main.rs b/router/src/main.rs index 21cd6649..bfc77913 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -14,6 +14,7 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_router::config::Config; +use text_generation_router::usage_stats; use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; @@ -87,6 +88,10 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, + #[clap(long, env, default_value_t)] + disable_usage_stats: bool, + #[clap(long, env, default_value_t)] + disable_crash_reports: bool, } #[derive(Debug, Subcommand)] @@ -128,6 +133,8 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + disable_usage_stats, + disable_crash_reports, command, } = args; @@ -210,7 +217,11 @@ async fn main() -> Result<(), RouterError> { } let api = if use_api { if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { - let cache = Cache::default(); + let cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active using cache defaults"); Type::Cache(cache) } else { @@ -320,6 +331,7 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); + let tokenizer_class = tokenizer_config.tokenizer_class.clone(); let tokenizer: Option = tokenizer_filename.and_then(|filename| { let mut tokenizer = Tokenizer::from_file(filename).ok(); @@ -374,8 +386,47 @@ async fn main() -> Result<(), RouterError> { } }; + // Only send usage stats when TGI is run in container and the function 
returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + + let user_agent = if !disable_usage_stats && is_container { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_class, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } else { + None + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + // Run server - server::run( + let result = server::run( master_shard_uds_path, model_info, compat_return_full_text, @@ -406,8 +457,41 @@ async fn main() -> Result<(), RouterError> { max_client_batch_size, print_schema_command, ) - .await?; - Ok(()) + .await; + + match result { + Ok(_) => { + if let Some(ref ua) = user_agent { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + stop_event.send().await; + }; + Ok(()) + } + Err(e) => { + if let Some(ref ua) = user_agent { + if !disable_crash_reports { + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some(e.to_string()), + ); + error_event.send().await; + } else { + let unknow_error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some("unknow_error".to_string()), + ); + unknow_error_event.send().await; + } + }; + Err(RouterError::WebServer(e)) + } + } } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: diff --git a/router/src/server.rs b/router/src/server.rs index db8b16ad..c56c39a3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -11,10 +11,11 @@ use crate::kserve::{ }; use crate::validation::ValidationError; use crate::{ - BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, - Usage, Validation, + BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, + GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, + HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, + SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse, + ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -23,7 +24,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -185,7 +186,7 @@ pub(crate) async fn generate_internal( span: tracing::Span, ) -> Result<(HeaderMap, Json), 
(StatusCode, Json)> { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // Do not long ultra long inputs, like image payloads. tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]); @@ -301,25 +302,15 @@ pub(crate) async fn generate_internal( ); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_validation_duration", - validation_time.as_secs_f64() - ); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_inference_duration", - inference_time.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_mean_time_per_token_duration", - time_per_token.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_generated_tokens", - response.generated_text.generated_tokens as f64 - ); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration") + .record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens") + .record(response.generated_text.generated_tokens as f64); // Send response let mut output_text = response.generated_text.text; @@ -399,7 +390,7 @@ async fn generate_stream_internal( span: tracing::Span, ) -> (HeaderMap, impl Stream>) { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); tracing::debug!("Input: {}", req.inputs); @@ -427,12 +418,12 @@ async fn generate_stream_internal( let best_of = req.parameters.best_of.unwrap_or(1); if best_of != 1 { let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else if req.parameters.decoder_input_details { let err = InferError::from(ValidationError::PrefillDetailsStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else { @@ -500,13 +491,13 @@ async fn generate_stream_internal( span.record("seed", format!("{:?}", generated_text.seed)); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); - metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + 
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64); // StreamResponse end_reached = true; @@ -553,7 +544,7 @@ async fn generate_stream_internal( // Skip if we already sent an error if !end_reached && !error { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } @@ -572,8 +563,8 @@ request_body = CompletionRequest, responses( (status = 200, description = "Generated Chat Completion", content( -("application/json" = Completion), -("text/event-stream" = CompletionCompleteChunk), +("application/json" = CompletionFinal), +("text/event-stream" = Chunk), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), @@ -604,9 +595,10 @@ async fn completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { + model, max_tokens, seed, stop, @@ -625,7 +617,7 @@ async fn completions( // if suffix is present throw an error if req.suffix.is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -637,7 +629,7 @@ async fn completions( } if req.prompt.0.len() > info.max_client_batch_size { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -675,7 +667,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - ..Default::default() + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), }, }) .collect(); @@ -820,6 +812,10 @@ async fn completions( } }; + let stream = stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1009,8 +1005,9 @@ async fn chat_completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let ChatRequest { + model, logprobs, max_tokens, messages, @@ -1039,7 +1036,7 @@ async fn chat_completions( // response_format and tools are mutually exclusive if response_format.is_some() && tools.as_ref().is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -1053,7 +1050,7 @@ async fn chat_completions( let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { Ok(grammar) => grammar, Err(err) => { - 
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1082,7 +1079,7 @@ async fn chat_completions( let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { Ok(inputs) => inputs, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1116,7 +1113,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, - ..Default::default() + adapter_id: model.filter(|m| *m != "tgi").map(String::from), }, }; @@ -1178,6 +1175,11 @@ async fn chat_completions( span, ) .await; + + let response_stream = response_stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1190,39 +1192,33 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - // gen_text should be valid json - let gen_text_value: Value = - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - })?; + let gen_text_value: Value = serde_json::from_str(&generation.generated_text) + .map_err(|e| InferError::ToolError(e.to_string()))?; + + let function = gen_text_value.get("function").ok_or(InferError::ToolError( + "No function found in generated text".to_string(), + ))?; + + let name = function + .get("_name") + .and_then(Value::as_str) + .ok_or(InferError::ToolError( + "No _name found in generated text".to_string(), + ))? 
+ .to_string(); + + let mut arguments = function.clone(); + if let Value::Object(ref mut props) = arguments { + props.remove("_name"); + } + let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: gen_text_value - .get("function") - .and_then(|f| f.get("_name")) - .and_then(|name| name.as_str()) - .unwrap_or("default_function_name") - .to_string(), - // Serialize the JSON object obtained from "function" to an escaped JSON string - arguments: gen_text_value - .get("function") - .map(|f| { - let mut f_cloned = f.clone(); - if let Value::Object(ref mut props) = f_cloned { - props.remove("_name"); - } - f_cloned - }) - .unwrap_or_default(), + name, + arguments, }, }]; (Some(tool_calls), None) @@ -1280,7 +1276,7 @@ async fn vertex_compatibility( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance if req.instances.is_empty() { @@ -1454,6 +1450,14 @@ pub async fn run( GrammarType, ChatRequest, Message, + MessageContent, + MessageChunk, + Url, + FunctionName, + OutputMessage, + TextMessage, + ToolCallMessage, + ToolCallDelta, ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, @@ -1488,6 +1492,7 @@ pub async fn run( ToolCall, Function, FunctionDefinition, + ToolChoice, ) ), tags( diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs new file mode 100644 index 00000000..8559ae90 --- /dev/null +++ b/router/src/usage_stats.rs @@ -0,0 +1,355 @@ +use crate::config::Config; +use csv::ReaderBuilder; +use reqwest::header::HeaderMap; +use serde::Serialize; +use std::{ + fs::File, + io::{self, BufRead}, + path::Path, + process::Command, + time::Duration, +}; +use uuid::Uuid; + +const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; + +#[derive(Debug, Clone, Serialize)] +pub struct UserAgent { + pub uid: String, + pub args: Args, + pub env: Env, +} + +impl UserAgent { + pub fn new(reduced_args: Args) -> Self { + Self { + uid: Uuid::new_v4().to_string(), + args: reduced_args, + env: Env::new(), + } + } +} + +#[derive(Serialize, Debug)] +pub enum EventType { + Start, + Stop, + Error, +} + +#[derive(Debug, Serialize)] +pub struct UsageStatsEvent { + user_agent: UserAgent, + event_type: EventType, + #[serde(skip_serializing_if = "Option::is_none")] + error_reason: Option, +} + +impl UsageStatsEvent { + pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option) -> Self { + Self { + user_agent, + event_type, + error_reason, + } + } + pub async fn send(&self) { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + let body = serde_json::to_string(&self).unwrap(); + let client = reqwest::Client::new(); + let _ = client + .post(TELEMETRY_URL) + .headers(headers) + .body(body) + .timeout(Duration::from_secs(5)) + .send() + .await; + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct Args { + model_config: Option, + tokenizer_config: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + 
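
Because `UsageStatsEvent` above marks `error_reason` with `skip_serializing_if = "Option::is_none"`, start and stop events omit that field from the JSON payload entirely; only error events carry it. A trimmed-down serde sketch of that behaviour (struct and values illustrative, not the full telemetry schema):

use serde::Serialize;

// Illustration only: `error_reason` disappears from the serialized payload
// unless an error is actually being reported.
#[derive(Serialize)]
struct EventPayload {
    event_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    error_reason: Option<String>,
}

fn main() {
    let start = EventPayload { event_type: "Start".into(), error_reason: None };
    let error = EventPayload {
        event_type: "Error".into(),
        error_reason: Some("example error".into()),
    };
    // {"event_type":"Start"}
    println!("{}", serde_json::to_string(&start).unwrap());
    // {"event_type":"Error","error_reason":"example error"}
    println!("{}", serde_json::to_string(&error).unwrap());
}
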
disable_grammar_support: bool, + max_client_batch_size: usize, + disable_usage_stats: bool, + disable_crash_reports: bool, +} + +impl Args { + #[allow(clippy::too_many_arguments)] + pub fn new( + model_config: Option, + tokenizer_config: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + disable_usage_stats: bool, + disable_crash_reports: bool, + ) -> Self { + Self { + model_config, + tokenizer_config, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + } + } +} + +/// This is more or less a copy of the code from the `text-generation-launcher` crate to avoid a dependency +#[derive(Serialize, Debug, Clone)] +pub struct Env { + git_sha: &'static str, + docker_label: &'static str, + nvidia_info: Option>, + xpu_info: Option>, + system_env: SystemInfo, +} + +#[derive(Debug, Serialize, Clone)] +struct NvidiaSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + pstate: String, + pcie_link_gen_max: String, + pcie_link_gen_current: String, + temperature_gpu: String, + utilization_gpu: String, + utilization_memory: String, + memory_total: String, + memory_free: String, + memory_used: String, + reset_status_reset_required: String, + reset_status_drain_and_reset_recommended: String, + compute_cap: String, + ecc_errors_corrected_volatile_total: String, + mig_mode_current: String, + power_draw_instant: String, + power_limit: String, +} + +impl NvidiaSmiInfo { + fn new() -> Option> { + let output = Command::new("nvidia-smi") + .args([ + "--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(NvidiaSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + pstate: record[3].to_string(), + pcie_link_gen_max: record[4].to_string(), + pcie_link_gen_current: record[5].to_string(), + temperature_gpu: record[6].to_string(), + utilization_gpu: record[7].to_string(), + utilization_memory: record[8].to_string(), + memory_total: record[9].to_string(), + memory_free: record[10].to_string(), + memory_used: record[11].to_string(), + reset_status_reset_required: record[12].to_string(), + reset_status_drain_and_reset_recommended: 
record[13].to_string(), + compute_cap: record[14].to_string(), + ecc_errors_corrected_volatile_total: record[15].to_string(), + mig_mode_current: record[16].to_string(), + power_draw_instant: record[17].to_string(), + power_limit: record[18].to_string(), + }); + } + + Some(infos) + } +} + +#[derive(Debug, Serialize, Clone)] +struct XpuSmiInfo { + device_id: usize, + gpu_utilization: f32, + gpu_power: f32, + gpu_core_temperature: f32, + gpu_memory_bandwidth_utilization: f32, +} + +impl XpuSmiInfo { + /// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format + fn new() -> Option> { + let output = Command::new("xpu-smi") + .args([ + "dump", "-d", "-1", "-m", + "0,1,3,17", // Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization + "-n", "1", "-j", + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + let mut infos = Vec::new(); + + let json_data: serde_json::Value = match serde_json::from_str(&stdout) { + Ok(data) => data, + Err(_) => return None, + }; + + if let Some(metrics_data) = json_data.as_array() { + for entry in metrics_data { + let device_id = entry["deviceId"].as_u64()? as usize; + let gpu_utilization = entry["metrics"][0].as_f64()? as f32; + let gpu_power = entry["metrics"][1].as_f64()? as f32; + let gpu_core_temperature = entry["metrics"][2].as_f64()? as f32; + let gpu_memory_bandwidth_utilization = entry["metrics"][3].as_f64()? as f32; + + infos.push(XpuSmiInfo { + device_id, + gpu_utilization, + gpu_power, + gpu_core_temperature, + gpu_memory_bandwidth_utilization, + }); + } + } + + Some(infos) + } +} + +#[derive(Serialize, Debug, Clone)] +pub struct SystemInfo { + cpu_count: usize, + cpu_type: String, + total_memory: u64, + architecture: String, + platform: String, +} + +impl SystemInfo { + fn new() -> Self { + let mut system = sysinfo::System::new_all(); + system.refresh_all(); + + let cpu_count = system.cpus().len(); + let cpu_type = system.cpus()[0].brand().to_string(); + let total_memory = system.total_memory(); + let architecture = std::env::consts::ARCH.to_string(); + let platform = format!( + "{}-{}-{}", + std::env::consts::OS, + std::env::consts::FAMILY, + std::env::consts::ARCH + ); + Self { + cpu_count, + cpu_type, + total_memory, + architecture, + platform, + } + } +} + +impl Default for Env { + fn default() -> Self { + Self::new() + } +} + +impl Env { + pub fn new() -> Self { + Self { + system_env: SystemInfo::new(), + nvidia_info: NvidiaSmiInfo::new(), + xpu_info: XpuSmiInfo::new(), + git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), + docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), + } + } +} + +pub fn is_container() -> io::Result { + let path = Path::new("/proc/self/cgroup"); + let file = File::open(path)?; + let reader = io::BufReader::new(file); + + for line in reader.lines() { + let line = line?; + // Check for common container runtimes + if line.contains("/docker/") + || line.contains("/docker-") + || line.contains("/kubepods/") + || line.contains("/kubepods-") + || line.contains("containerd") + || line.contains("crio") + || line.contains("podman") + { + return Ok(true); + } + } + Ok(false) +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 12cf2ab3..07ad14c9 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -157,7 +157,7 @@ impl Validation { )); } - metrics::histogram!("tgi_request_input_length", 
input_length as f64); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, input_length, max_new_tokens)) } // Return inputs without validation @@ -384,7 +384,7 @@ impl Validation { ignore_eos_token: false, }; - metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); + metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64); Ok(ValidGenerateRequest { inputs, diff --git a/server/Makefile b/server/Makefile index 0099c56a..209fc44e 100644 --- a/server/Makefile +++ b/server/Makefile @@ -5,6 +5,7 @@ include Makefile-awq include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica +include Makefile-fbgemm unit-tests: pytest -s -vv -m "not private" tests @@ -21,13 +22,15 @@ gen-server: install-server: gen-server pip install pip --upgrade pip install -r requirements_cuda.txt - pip install -e ".[bnb, accelerate, quantize, peft, outlines]" + pip install -e ".[accelerate, quantize, peft, outlines]" install: install-cuda echo "Installed server" -install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention +install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm + pip install -e ".[bnb]" + pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm @@ -37,3 +40,4 @@ run-dev: export-requirements: poetry export -o requirements_cuda.txt --without-hashes poetry export -o requirements_rocm.txt --without-hashes + poetry export -o requirements_intel.txt --without-hashes diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm new file mode 100644 index 00000000..38f8f31f --- /dev/null +++ b/server/Makefile-fbgemm @@ -0,0 +1,15 @@ +fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca + +build-fbgemm: + chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \ + git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ + cp fbgemm_remove_unused.patch fbgemm && \ + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build + +install-fbgemm: build-fbgemm + cd fbgemm/fbgemm_gpu && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 2f2b5ef6..f1f80529 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,14 +1,14 @@ -commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa +commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ git clone https://github.com/Narsil/vllm.git vllm; \ fi - cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build + cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build install-vllm-cuda: build-vllm-cuda - cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . + cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e . 
build-vllm-rocm: if [ ! -d 'vllm' ]; then \ diff --git a/server/fbgemm_remove_unused.patch b/server/fbgemm_remove_unused.patch new file mode 100644 index 00000000..ad6af811 --- /dev/null +++ b/server/fbgemm_remove_unused.patch @@ -0,0 +1,306 @@ +diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt +index 2244ea6f..96265a48 100644 +--- a/fbgemm_gpu/CMakeLists.txt ++++ b/fbgemm_gpu/CMakeLists.txt +@@ -94,14 +94,14 @@ endif() + # Build Experimental Modules + ################################################################################ + +-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) +- # TODO: Figure out NCCL/RCCL integration with ROCm +- add_subdirectory(experimental/example) +-endif() +- +-if(NOT FBGEMM_CPU_ONLY) +- add_subdirectory(experimental/gemm) +-endif() ++# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) ++# # TODO: Figure out NCCL/RCCL integration with ROCm ++# add_subdirectory(experimental/example) ++# endif() ++ ++# if(NOT FBGEMM_CPU_ONLY) ++# add_subdirectory(experimental/gemm) ++# endif() + + if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) + # CUTLASS currently doesn't build on ROCm and CK hasnt yet been added: +diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake +index c56773fe..0c0d349e 100644 +--- a/fbgemm_gpu/FbgemmGpu.cmake ++++ b/fbgemm_gpu/FbgemmGpu.cmake +@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources} + ################################################################################ + + set(fbgemm_gpu_sources_static_cpu +- codegen/training/forward/embedding_forward_split_cpu.cpp +- codegen/inference/embedding_forward_quantized_host_cpu.cpp +- codegen/training/backward/embedding_backward_dense_host_cpu.cpp +- codegen/utils/embedding_bounds_check_host_cpu.cpp +- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +- src/input_combine_ops/input_combine_cpu.cpp +- src/layout_transform_ops/layout_transform_ops_cpu.cpp ++ # codegen/training/forward/embedding_forward_split_cpu.cpp ++ # codegen/inference/embedding_forward_quantized_host_cpu.cpp ++ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp ++ # codegen/utils/embedding_bounds_check_host_cpu.cpp ++ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp ++ # src/input_combine_ops/input_combine_cpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_cpu.cpp + src/quantize_ops/quantize_ops_cpu.cpp + src/quantize_ops/quantize_ops_meta.cpp +- src/sparse_ops/sparse_ops_cpu.cpp +- 
src/sparse_ops/sparse_ops_meta.cpp +- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp +- src/split_embeddings_cache/linearize_cache_indices.cpp +- src/split_embeddings_cache/lfu_cache_populate_byte.cpp +- src/split_embeddings_cache/lru_cache_populate_byte.cpp +- src/split_embeddings_cache/lxu_cache.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++ # src/sparse_ops/sparse_ops_cpu.cpp ++ # src/sparse_ops/sparse_ops_meta.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp ++ # src/split_embeddings_cache/linearize_cache_indices.cpp ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lru_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lxu_cache.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++) + + if(NOT FBGEMM_CPU_ONLY) + list(APPEND fbgemm_gpu_sources_static_cpu +- codegen/inference/embedding_forward_quantized_host.cpp +- codegen/utils/embedding_bounds_check_host.cpp +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp +- src/layout_transform_ops/layout_transform_ops_gpu.cpp +- src/memory_utils/memory_utils.cpp +- src/memory_utils/memory_utils_ops.cpp +- src/memory_utils/memory_utils_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp ++ # codegen/inference/embedding_forward_quantized_host.cpp ++ # codegen/utils/embedding_bounds_check_host.cpp ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_gpu.cpp ++ # src/memory_utils/memory_utils.cpp ++ # src/memory_utils/memory_utils_ops.cpp ++ # src/memory_utils/memory_utils_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp + src/quantize_ops/quantize_ops_gpu.cpp +- src/sparse_ops/sparse_ops_gpu.cpp +- src/split_embeddings_utils/split_embeddings_utils.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cu +- src/metric_ops/metric_ops_host.cpp +- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp +- src/input_combine_ops/input_combine_gpu.cpp +- codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ # src/sparse_ops/sparse_ops_gpu.cpp ++ # src/split_embeddings_utils/split_embeddings_utils.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cu ++ # src/metric_ops/metric_ops_host.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp ++ # src/input_combine_ops/input_combine_gpu.cpp ++ # codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ ) + + if(NVML_LIB_PATH OR USE_ROCM) + message(STATUS "Adding merge_pooled_embeddings sources") +@@ -516,36 +518,36 @@ endif() + + if(NOT FBGEMM_CPU_ONLY) + set(fbgemm_gpu_sources_static_gpu +- codegen/utils/embedding_bounds_check.cu +- codegen/inference/embedding_forward_quantized_split_lookup.cu +- src/embedding_inplace_ops/embedding_inplace_update.cu +- src/histogram_binning_calibration_ops.cu +- src/input_combine_ops/input_combine.cu +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu +- src/memory_utils/memory_utils.cu +- 
src/memory_utils/memory_utils_ops.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu +- src/jagged_tensor_ops/dense_to_jagged_forward.cu +- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu +- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu +- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu +- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu +- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu +- src/jagged_tensor_ops/jagged_softmax_backward.cu +- src/jagged_tensor_ops/jagged_softmax_forward.cu +- src/jagged_tensor_ops/jagged_tensor_ops.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu +- src/jagged_tensor_ops/jagged_unique_indices.cu +- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +- src/layout_transform_ops/layout_transform_ops.cu +- src/metric_ops/metric_ops.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu +- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu ++ # codegen/utils/embedding_bounds_check.cu ++ # codegen/inference/embedding_forward_quantized_split_lookup.cu ++ # src/embedding_inplace_ops/embedding_inplace_update.cu ++ # src/histogram_binning_calibration_ops.cu ++ # src/input_combine_ops/input_combine.cu ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu ++ # src/memory_utils/memory_utils.cu ++ # src/memory_utils/memory_utils_ops.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu ++ # src/jagged_tensor_ops/dense_to_jagged_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu ++ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_softmax_backward.cu ++ # src/jagged_tensor_ops/jagged_softmax_forward.cu ++ # src/jagged_tensor_ops/jagged_tensor_ops.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu ++ # src/jagged_tensor_ops/jagged_unique_indices.cu ++ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu ++ # src/layout_transform_ops/layout_transform_ops.cu ++ # src/metric_ops/metric_ops.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu + src/quantize_ops/quantize_bfloat16.cu + src/quantize_ops/quantize_fp8_rowwise.cu + src/quantize_ops/quantize_fused_8bit_rowwise.cu +@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY) + src/quantize_ops/quantize_msfp.cu + src/quantize_ops/quantize_padded_fp8_rowwise.cu + src/quantize_ops/quantize_mx.cu +- src/sparse_ops/sparse_async_cumsum.cu +- src/sparse_ops/sparse_block_bucketize_features.cu +- src/sparse_ops/sparse_bucketize_features.cu +- 
src/sparse_ops/sparse_batched_unary_embeddings.cu +- src/sparse_ops/sparse_compute_frequency_sequence.cu +- src/sparse_ops/sparse_expand_into_jagged_permute.cu +- src/sparse_ops/sparse_group_index.cu +- src/sparse_ops/sparse_index_add.cu +- src/sparse_ops/sparse_index_select.cu +- src/sparse_ops/sparse_invert_permute.cu +- src/sparse_ops/sparse_pack_segments_backward.cu +- src/sparse_ops/sparse_pack_segments_forward.cu +- src/sparse_ops/sparse_permute_1d.cu +- src/sparse_ops/sparse_permute_2d.cu +- src/sparse_ops/sparse_permute102.cu +- src/sparse_ops/sparse_permute_embeddings.cu +- src/sparse_ops/sparse_range.cu +- src/sparse_ops/sparse_reorder_batched_ad.cu +- src/sparse_ops/sparse_segment_sum_csr.cu +- src/sparse_ops/sparse_zipf.cu +- src/split_embeddings_cache/lfu_cache_find.cu +- src/split_embeddings_cache/lfu_cache_populate.cu +- src/split_embeddings_cache/lfu_cache_populate_byte.cu +- src/split_embeddings_cache/lru_cache_find.cu +- src/split_embeddings_cache/lru_cache_populate.cu +- src/split_embeddings_cache/lru_cache_populate_byte.cu +- src/split_embeddings_cache/lxu_cache.cu +- src/split_embeddings_cache/linearize_cache_indices.cu +- src/split_embeddings_cache/reset_weight_momentum.cu +- src/split_embeddings_utils/generate_vbe_metadata.cu +- src/split_embeddings_utils/get_infos_metadata.cu +- src/split_embeddings_utils/radix_sort_pairs.cu +- src/split_embeddings_utils/transpose_embedding_input.cu) ++ # src/sparse_ops/sparse_async_cumsum.cu ++ # src/sparse_ops/sparse_block_bucketize_features.cu ++ # src/sparse_ops/sparse_bucketize_features.cu ++ # src/sparse_ops/sparse_batched_unary_embeddings.cu ++ # src/sparse_ops/sparse_compute_frequency_sequence.cu ++ # src/sparse_ops/sparse_expand_into_jagged_permute.cu ++ # src/sparse_ops/sparse_group_index.cu ++ # src/sparse_ops/sparse_index_add.cu ++ # src/sparse_ops/sparse_index_select.cu ++ # src/sparse_ops/sparse_invert_permute.cu ++ # src/sparse_ops/sparse_pack_segments_backward.cu ++ # src/sparse_ops/sparse_pack_segments_forward.cu ++ # src/sparse_ops/sparse_permute_1d.cu ++ # src/sparse_ops/sparse_permute_2d.cu ++ # src/sparse_ops/sparse_permute102.cu ++ # src/sparse_ops/sparse_permute_embeddings.cu ++ # src/sparse_ops/sparse_range.cu ++ # src/sparse_ops/sparse_reorder_batched_ad.cu ++ # src/sparse_ops/sparse_segment_sum_csr.cu ++ # src/sparse_ops/sparse_zipf.cu ++ # src/split_embeddings_cache/lfu_cache_find.cu ++ # src/split_embeddings_cache/lfu_cache_populate.cu ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cu ++ # src/split_embeddings_cache/lru_cache_find.cu ++ # src/split_embeddings_cache/lru_cache_populate.cu ++ # src/split_embeddings_cache/lru_cache_populate_byte.cu ++ # src/split_embeddings_cache/lxu_cache.cu ++ # src/split_embeddings_cache/linearize_cache_indices.cu ++ # src/split_embeddings_cache/reset_weight_momentum.cu ++ # src/split_embeddings_utils/generate_vbe_metadata.cu ++ # src/split_embeddings_utils/get_infos_metadata.cu ++ # src/split_embeddings_utils/radix_sort_pairs.cu ++ # src/split_embeddings_utils/transpose_embedding_input.cu) ++ ) + + set_source_files_properties(${fbgemm_gpu_sources_static_gpu} + PROPERTIES COMPILE_OPTIONS +diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +index 01f1d6ab..a6b8d7a8 100644 +--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt ++++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories + ${THIRDPARTY}/json/include + ${NCCL_INCLUDE_DIRS}) + 
+-set(attention_ops_sources +- src/attention/attention.cpp +- src/attention/gqa_attn_splitk.cu) ++# set(attention_ops_sources ++# src/attention/attention.cpp ++# src/attention/gqa_attn_splitk.cu) + + set(quantize_ops_sources + src/quantize/cutlass_extensions.cu + src/quantize/quantize.cu + src/quantize/quantize.cpp) + +-set(comm_ops_sources +- src/comm/car.cu +- src/comm/car.cpp) ++# set(comm_ops_sources ++# src/comm/car.cu ++# src/comm/car.cpp) + + set(experimental_gen_ai_cpp_source_files +- ${attention_ops_sources} ++ # ${attention_ops_sources} + ${quantize_ops_sources} +- ${comm_ops_sources}) ++ # ${comm_ops_sources} ++) + + set_source_files_properties(${experimental_gen_ai_cpp_source_files} + PROPERTIES INCLUDE_DIRECTORIES diff --git a/server/fix_torch90a.sh b/server/fix_torch90a.sh new file mode 100755 index 00000000..5e444828 --- /dev/null +++ b/server/fix_torch90a.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# This script is required to patch torch < 2.4 +# It adds the 90a cuda target (H100) +# This target is required to build FBGEMM kernels + +torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|') + +sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch +sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch +sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi index 663984d0..53464719 100644 --- a/server/marlin/marlin_kernels/__init__.pyi +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -59,3 +59,18 @@ def marlin_gemm( Matrix multiplication using Marlin kernels. """ ... + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return torch.ops._C.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 37eccef6..04e1530f 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gptq_marlin_repack", &gptq_marlin_repack, "Repack GPTQ parameters for Marlin"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); + // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. 
+ m.def("fp8_marlin_gemm", &fp8_marlin_gemm); } diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index d1caaab7..102c058e 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &workspace, int64_t size_m, int64_t size_n, int64_t size_k); +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k); + #endif diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu new file mode 100644 index 00000000..aaef67e5 --- /dev/null +++ b/server/marlin/marlin_kernels/fp8_marlin.cu @@ -0,0 +1,1308 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "./gptq_marlin.cuh" +#include "./gptq_marlin_dtypes.cuh" + +using namespace gptq_marlin; + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace fp8_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA 
= 3, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + int slice_k_start = tb_k * slice_row; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We scale a `half2` tile in row-major layout for column-wise quantization. + int s_sh_rd = + 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
+ #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
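`thread_block_reduce` above folds the partial sums of warps that share an output tile using a logarithmic number of shared-memory round trips. A simplified picture of that folding, ignoring the fragment layout and bank details and assuming a power-of-two number of warp groups as in the kernel:

```python
# Illustrative sketch of the logarithmic reduction pattern.
def tree_reduce(partials):
    """Pairwise fold: halve the number of active accumulators each round."""
    vals = list(partials)
    off = len(vals) // 2
    while off >= 1:
        for i in range(off):
            vals[i] += vals[i + off]   # "warp i" absorbs "warp i + off"
        vals = vals[:off]
        off //= 2
    return vals[0]

print(tree_reduce([1.0, 2.0, 3.0, 4.0]))   # 10.0
```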
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
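`global_reduce` above serializes the cross-threadblock reduction through the output buffer `C`, as the comment explains: every block except the first reads the running sum, every block except the last writes it back in fp16, and the last block keeps fp32 accumulators for the final scaled write-out. A toy model of that hand-off (illustrative only):

```python
# Illustrative sketch of the serial slice reduction through the C buffer.
def serial_slice_reduce(block_partials):
    c = 0.0                                   # the shared output location
    for idx, frag in enumerate(block_partials):
        first = idx == 0
        last = idx == len(block_partials) - 1
        acc = frag + (0.0 if first else c)    # accumulate in "registers"
        if not last:
            c = acc                           # round-trip through C (fp16 in the kernel)
        else:
            return acc                        # last block keeps fp32 and writes the result

print(serial_slice_reduce([1.0, 2.0, 3.5]))   # 6.5
```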
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
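In `write_result` above, results are first staged in shared memory with stride `c_sh_stride = 2 * thread_n_blocks + 1`. One plausible reading of the `+ 1` is the usual odd-stride padding trick: staggering rows by one bank keeps column-wise accesses during the reshuffle from serializing on a single bank. A tiny sketch assuming 32 banks and a hypothetical row length of 32:

```python
# Illustrative sketch of why an odd shared-memory stride avoids bank conflicts.
def bank_of(row, col, stride, banks=32):
    return (row * stride + col) % banks

unpadded = {bank_of(r, 0, 32) for r in range(8)}   # every row -> bank 0
padded   = {bank_of(r, 0, 33) for r in range(8)}   # rows spread across banks
print(len(unpadded), len(padded))                  # 1 8
```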
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + + thread_block_reduce(); + + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + + start_pipes(); + } + } + } +} + + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ + locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, + int group_size) { + int tb_n = th_config.thread_n; + + // Get max scale groups per thread-block + // Fixed for channelwise + int tb_groups = 1; + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; +} + +bool is_valid_cache_size(thread_config_t const& th_config, int 
max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, + prob_k, num_bits, group_size); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) + +template +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { + TORCH_CHECK(num_bits == 8, "num_bits must be 8. 
Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, + group_size, max_shared_mem); + } + + TORCH_CHECK( + exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = -1; + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + int* locks = (int*)workspace; + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 8, "num_bits must be 8. 
Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + // Channelwise only for FP8 + TORCH_CHECK(b_scales.size(0) == 1) + num_groups = b_scales.size(0); + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), size_m, + size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, + dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else { + TORCH_CHECK(false, 
"fp8_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/server/marlin/setup.py b/server/marlin/setup.py index aed84e9e..cc38bccf 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -9,6 +9,7 @@ setup( CUDAExtension( name="marlin_kernels", sources=[ + "marlin_kernels/fp8_marlin.cu", "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", "marlin_kernels/marlin_cuda_kernel.cu", diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 9a8da0d6..1e3aaf6b 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -2,6 +2,7 @@ import torch from text_generation_server.layers import ( TensorParallelEmbedding, ) +from text_generation_server.utils.weights import DefaultWeightsLoader class ProcessGroup: @@ -42,7 +43,12 @@ class Weights: def test_weight_hub_files_offline_error(): vocab_size = 17 - weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) + weights = Weights( + rank=0, + world_size=1, + vocab_size=vocab_size, + hidden_dim=256, + ) embeddings = TensorParallelEmbedding("", weights) input_ids = torch.arange(vocab_size) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 8f88b1f8..d2d2b76e 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -1,13 +1,48 @@ import pytest import torch -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.marlin import MarlinWeight +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, + WeightsLoader, +) +from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader +from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader +from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path + +@pytest.fixture +def gptq_weights_loader(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="gptq", + quantize="gptq", + sym=True, + ) + + +@pytest.fixture +def gptq_weights_loader_awq(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="awq", + quantize="awq", + sym=True, + ) + + +@pytest.fixture +def marlin_weights_loader(): + return MarlinWeightsLoader(bits=4, is_marlin_24=False) + + dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -58,7 +93,7 @@ dummy_file_system = { dtype=torch.float32, ), }, - "test_get_multi_weights_row": { + "test_get_weights_row": { "weight.weight": torch.tensor( [ [1, 2], @@ -101,7 +136,7 @@ dummy_file_system = { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), }, - "test_get_multi_weights_row_gptq": { + "test_get_weights_row_gptq": { "weight.qweight": torch.tensor( [ [1, 2], @@ -200,7 +235,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_exl2": { + "test_get_weights_row_exl2": { "weight.q_weight": torch.tensor( [ [1, 2], @@ -245,7 +280,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), 
"weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_marlin": { + "test_get_weights_row_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, @@ -308,6 +343,7 @@ class MockWeights(Weights): dummy_fs, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, ): routing = {} self.dummy_fs = dummy_fs @@ -327,6 +363,12 @@ class MockWeights(Weights): self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = ( + # We don't need to get linear layers, so just wrap raw tensors. + DefaultWeightsLoader(lambda x: x) + if weights_loader is None + else weights_loader + ) self._handles = {} def _get_handle(self, filename: Union[Path, str]): @@ -412,12 +454,10 @@ def test_get_weights_col_packed(): ) prefix = "weight" - quantize = None block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -448,12 +488,10 @@ def test_get_weights_col_packed_block_size(): ) prefix = "weight" - quantize = None block_sizes = 2 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -484,12 +522,10 @@ def test_get_weights_col_packed_block_size_arr(): ) prefix = "weight" - quantize = None block_sizes = [1, 1] w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -519,11 +555,9 @@ def test_get_multi_weights_col(): ) prefixes = ["weight", "weight"] - quantize = None w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -545,10 +579,10 @@ def test_get_multi_weights_col(): ) -def test_get_multi_weights_row(): +def test_get_weights_row(): weights = MockWeights( [ - "test_get_multi_weights_row", + "test_get_weights_row", ], device="cpu", dtype=torch.float32, @@ -557,11 +591,9 @@ def test_get_multi_weights_row(): ) prefix = "weight" - quantize = None - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) assert torch.allclose( @@ -576,7 +608,7 @@ def test_get_multi_weights_row(): # test_get_weights_col -def test_get_weights_col_awq(): +def test_get_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -585,14 +617,13 @@ def test_get_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -605,6 +636,7 @@ def test_get_weights_col_awq(): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -614,10 +646,11 @@ def test_get_weights_col_awq(): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_gtpq(): +def test_get_weights_col_gtpq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -626,14 +659,13 @@ def test_get_weights_col_gtpq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + 
weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -643,6 +675,7 @@ def test_get_weights_col_gtpq(): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -652,6 +685,7 @@ def test_get_weights_col_gtpq(): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -664,14 +698,13 @@ def test_get_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) scaled_scale_max = 0.3906 * 256 @@ -692,7 +725,7 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(): +def test_get_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_marlin", @@ -701,14 +734,13 @@ def test_get_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( @@ -723,7 +755,7 @@ def test_get_weights_col_marlin(): # test_get_weights_col_packed -def test_get_weights_col_packed_awq(): +def test_get_weights_col_packed_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -732,15 +764,14 @@ def test_get_weights_col_packed_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -751,6 +782,7 @@ def test_get_weights_col_packed_awq(): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -760,6 +792,7 @@ def test_get_weights_col_packed_awq(): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -773,15 +806,14 @@ def test_get_weights_col_packed_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -803,7 +835,7 @@ def test_get_weights_col_packed_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_packed_gptq(): +def test_get_weights_col_packed_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -812,14 +844,13 @@ def test_get_weights_col_packed_gptq(): dtype=torch.float32, 
process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -830,6 +861,7 @@ def test_get_weights_col_packed_gptq(): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -839,10 +871,11 @@ def test_get_weights_col_packed_gptq(): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(): +def test_get_weights_col_packed_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_marlin", @@ -851,14 +884,13 @@ def test_get_weights_col_packed_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -876,7 +908,7 @@ def test_get_weights_col_packed_marlin(): # test_get_multi_weights_col -def test_get_multi_weights_col_awq(): +def test_get_multi_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -885,14 +917,13 @@ def test_get_multi_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefixes = ["weight"] - quantize = "awq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -903,6 +934,7 @@ def test_get_multi_weights_col_awq(): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -912,6 +944,7 @@ def test_get_multi_weights_col_awq(): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -924,22 +957,21 @@ def test_get_multi_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" try: w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) except ValueError as e: assert e.args[0] == "get_multi_weights_col is not supported for exl2" -def test_get_multi_weights_col_gptq(): +def test_get_multi_weights_col_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -948,14 +980,13 @@ def test_get_multi_weights_col_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -966,6 +997,7 @@ def test_get_multi_weights_col_gptq(): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -975,10 +1007,11 @@ def test_get_multi_weights_col_gptq(): assert 
torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(): +def test_get_multi_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_marlin", @@ -987,14 +1020,13 @@ def test_get_multi_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -1007,26 +1039,25 @@ def test_get_multi_weights_col_marlin(): assert torch.allclose(w.s, expected_weight.s), "s mismatch" -# test_get_multi_weights_row +# test_get_weights_row -def test_get_multi_weights_row_awq(): +def test_get_weights_row_awq(gptq_weights_loader_awq): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1036,6 +1067,7 @@ def test_get_multi_weights_row_awq(): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -1045,26 +1077,26 @@ def test_get_multi_weights_row_awq(): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_exl2(): +def test_get_weights_row_exl2(): weights = MockWeights( [ - "test_get_multi_weights_row_exl2", + "test_get_weights_row_exl2", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) print(w) @@ -1086,23 +1118,22 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_multi_weights_row_gptq(): +def test_get_weights_row_gptq(gptq_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1112,6 +1143,7 @@ def test_get_multi_weights_row_gptq(): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -1121,26 +1153,26 @@ def test_get_multi_weights_row_gptq(): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert 
w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_marlin(): +def test_get_weights_row_marlin(marlin_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_marlin", + "test_get_weights_row_marlin", ], device="cpu", dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd..8ec2a5ae 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -8,6 +8,7 @@ from typing import Optional from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.log import log_master app = typer.Typer() @@ -87,10 +88,21 @@ def serve( ) if len(lora_adapter_ids) > 0: - logger.warning( - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + log_master( + logger.warning, + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", ) + # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled + # and warn the user + if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: + log_master( + logger.warning, + f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.", + ) + global CUDA_GRAPHS + CUDA_GRAPHS = None + # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value @@ -332,6 +344,7 @@ def quantize( upload_to_model_id: Optional[str] = None, percdamp: float = 0.01, act_order: bool = False, + groupsize: int = 128, ): if revision is None: revision = "main" @@ -346,13 +359,14 @@ def quantize( quantize( model_id=model_id, bits=4, - groupsize=128, + groupsize=groupsize, output_dir=output_dir, revision=revision, trust_remote_code=trust_remote_code, upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 99c490d5..54da63e8 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -3,6 +3,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -136,7 +137,10 @@ if ENGINE != "triton": try: import flash_attn_2_cuda - logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) except ImportError as e: if major >= 8: architecture_suffix = f"-{SYSTEM}" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index ca39919c..aae2bd1a 100644 --- 
a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -1,15 +1,18 @@ -import torch -from loguru import logger +from dataclasses import dataclass from functools import lru_cache + import bitsandbytes as bnb +import torch from bitsandbytes.nn import Int8Params, Params4bit +from text_generation_server.utils.weights import UnquantizedWeight -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) +@dataclass +class BNBWeight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0) class Linear8bitLt(torch.nn.Module): @@ -70,6 +73,22 @@ class Linear8bitLt(torch.nn.Module): return out +@dataclass +class BNBFP4Weight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="fp4") + + +@dataclass +class BNBNF4Weight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="nf4") + + class Linear4bit(torch.nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index fd22b5c6..b1e5235a 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -1,5 +1,23 @@ +from dataclasses import dataclass + import torch from EETQ import quant_weights, w8_a16_gemm +from text_generation_server.utils.weights import UnquantizedWeight + + +@dataclass +class EETQWeight(UnquantizedWeight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + try: + from text_generation_server.layers.eetq import EETQLinear + + return EETQLinear(self.weight, bias) + except ImportError: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) class EETQLinear(torch.nn.Module): diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index f6cb729e..a6e07f45 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,9 +1,12 @@ -import torch from dataclasses import dataclass +from typing import List, Union + +import torch +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class Exl2Weight: +class Exl2Weight(Weight): """ Exllama2 exl2 quantized weights. """ @@ -21,3 +24,55 @@ class Exl2Weight: @property def device(self) -> torch.device: return self.q_weight.device + + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.gptq import ExllamaQuantLinear + + return ExllamaQuantLinear(self, bias) + + +class Exl2WeightsLoader(WeightsLoader): + """Loader for exl2-quantized weights.""" + + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." 
+ ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl") + + def get_weights_col(self, weights: Weights, prefix: str): + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + raise ValueError("get_multi_weights_col is not supported for exl2") + + def get_weights_row(self, weights: Weights, prefix: str): + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index dd61d081..cdf16d6b 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,12 +1,58 @@ import torch +from dataclasses import dataclass +from typing import Optional, Union, List +from loguru import logger + +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import log_master, log_once + +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False +try: + import fbgemm_gpu.experimental.gen_ai + + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 +except (ImportError, ModuleNotFoundError): + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") + + +def get_fp8_linear() -> torch.nn.Module: + """ + Return an FP8 linear `Module` that is compatible with the current system. + """ + + if SYSTEM == "cuda": + major, minor = torch.cuda.get_device_capability() + if major == 8 and minor < 9: + from text_generation_server.layers.marlin import GPTQMarlinFP8Linear + + return GPTQMarlinFP8Linear + + # On other systems let Torch decide if the hardware supports FP8. 
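The `fp8_quantize` path that follows computes a dynamic per-tensor scale as `finfo.max / absmax`, then scales and clamps before casting, since a plain cast to float8 is unsaturated. A small self-contained sketch of that scaling math (not the patch's exact return convention):

```python
# Illustrative sketch of per-tensor FP8 scaling; not the patch's exact code.
import torch

def fp8_quantize_sketch(weight, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Scale so that the largest magnitude maps onto the dtype's max value.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qweight, scale

w = torch.tensor([0.5, -2.0, 1.0])
q, s = fp8_quantize_sketch(w)
print(s)               # 448 / 2.0 = 224, since e4m3fn's max is 448
print(q.float() / s)   # approximately recovers the original values
```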
+ return Fp8Linear + + +def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): + if FBGEMM_DYN_AVAILABLE: + qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( + weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype + ) + return qweight, scale -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -18,19 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + self.activation_scale_ub = activation_scale_ub + self.to_fp8 = to_fp8 + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + w = torch.cat(w, dim=dim) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=0) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + +@dataclass +class Fp8Weight(Weight): + weight: torch.Tensor + dtype: torch.dtype + weight_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None + + def get_linear(self, bias: torch.Tensor): + if self.weight_scale is None: + return 
get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear().from_fp8( + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + ) + + class Fp8Linear(torch.nn.Module): def __init__( self, - weight, + qweight, + scale, + scale_upper_bound, bias, + dtype, ) -> None: super().__init__() - self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) + self.dtype = dtype + self.qweight = qweight + self.scale = scale + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) self.bias = bias if bias is not None else None + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scale = fp8_quantize(weight) + return cls( + qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + ) + + @classmethod + def from_fp8(cls, weight, scale, input_scale, bias, dtype): + return cls( + qweight=weight, + scale=scale, + scale_upper_bound=input_scale, + bias=bias, + dtype=dtype, + ) + def forward(self, input: torch.Tensor) -> torch.Tensor: + if FBGEMM_MM_AVAILABLE: + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound + ) + + y = torch.ops.fbgemm.f8f8bf16_rowwise( + qinput, + self.qweight, + scale, + self.scale, + use_fast_accum=True, + bias=self.bias, + ) + return y.to(self.dtype) + qinput, scale = fp8_quantize(input) output, _ = torch._scaled_mm( qinput, diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 56080145..6feca275 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,30 +1,23 @@ -from dataclasses import dataclass import os -from typing import Optional +from dataclasses import dataclass +from typing import List, Optional, Union + import torch -from text_generation_server.utils.import_utils import ( - SYSTEM, -) +from loguru import logger +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool - - -@dataclass -class GPTQWeight: +class GPTQWeight(Weight): qweight: torch.Tensor qzeros: torch.Tensor scales: torch.Tensor g_idx: Optional[torch.Tensor] bits: int groupsize: int + use_awq_kernel: bool use_exllama: bool def __post_init__(self): @@ -35,6 +28,50 @@ class GPTQWeight: def device(self) -> torch.device: return self.qweight.device + def get_linear(self, bias: torch.Tensor): + if self.use_awq_kernel: + if SYSTEM == "rocm": + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." 
+ ) + try: + from text_generation_server.layers.awq.quantize.qmodule import WQLinear + + return WQLinear( + w_bit=self.bits, + group_size=self.groupsize, + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bias=bias, + ) + except ImportError: + raise NotImplementedError( + "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" + ) + elif self.use_exllama: + try: + from text_generation_server.layers.gptq import ExllamaQuantLinear + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + return ExllamaQuantLinear(self, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + return QuantLinear( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + bias, + self.bits, + self.groupsize, + ) + try: major, _minor = torch.cuda.get_device_capability() @@ -51,6 +88,8 @@ elif CAN_EXLLAMA: if V2: from text_generation_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllamav2 import ( create_exllama_buffers, set_device, ) @@ -59,6 +98,8 @@ elif CAN_EXLLAMA: else: from text_generation_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllama import ( create_exllama_buffers, set_device, ) @@ -69,3 +110,457 @@ elif CAN_EXLLAMA: pass from text_generation_server.layers.gptq.quant_linear import QuantLinear + + +class GPTQWeightsLoader(WeightsLoader): + """ + Loader for GPTQ- and AWQ-quantized weights. 
+ """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." 
+ ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + use_exllama = ( + self.bits == 4 + and HAS_EXLLAMA + and self.quantize == "gptq" + and not self.desc_act + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + + if weights.process_group.size() > 1: + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // self.groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + from text_generation_server.layers.gptq import ( + CAN_EXLLAMA, + HAS_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + if use_exllama and self.groupsize != -1: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == 
"awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + + def _get_gptq_params(self, weights: Weights): + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) + self.quant_method = "gptq" diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 4d45822b..dc3b832f 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -9,11 +9,12 @@ from loguru import logger from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.utils.log import log_master try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error("exllamav2_kernels not installed.") + log_master(logger.warning, "exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 8d029817..0271d913 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional +from text_generation_server.utils.weights import DefaultWeightsLoader + DEV = torch.device("cuda:0") @@ -869,6 +871,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -891,6 +894,7 @@ def quantize( dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, + weights_loader=DefaultWeightsLoader(), ) hooks = [] for name, module in model.named_modules(): @@ -943,6 +947,7 @@ def quantize( percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -954,6 +959,7 @@ def quantize( state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index e94e5465..a97cc43a 100644 --- a/server/text_generation_server/layers/linear.py +++ 
b/server/text_generation_server/layers/linear.py @@ -1,7 +1,8 @@ from typing import Optional + import torch -from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM +from torch.nn import functional as F if SYSTEM == "rocm": try: @@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module): return F.linear(inp, self.weight, self.bias) -def get_linear(weight, bias, quantize): - if quantize is None: +def get_linear(weight, bias): + # Weights that are loaded through methods that are not + # quantization-aware are still bare tensors. We may want + # to change this in the future. + if isinstance(weight, torch.Tensor): if SYSTEM == "rocm": - linear = FastLinearROCm(weight, bias) + return FastLinearROCm(weight, bias) else: - linear = FastLinear(weight, bias) - elif quantize == "eetq": - try: - from text_generation_server.layers.eetq import EETQLinear + return FastLinear(weight, bias) - linear = EETQLinear(weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Linear - - linear = Fp8Linear(weight, bias) - elif quantize == "bitsandbytes": - try: - from text_generation_server.layers.bnb import ( - warn_deprecate_bnb, - Linear8bitLt, - ) - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - warn_deprecate_bnb() - linear = Linear8bitLt( - weight, - bias, - has_fp16_weights=False, - threshold=6.0, - ) - if bias is not None: - linear.bias = nn.Parameter(bias) - elif quantize == "bitsandbytes-fp4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="fp4", - ) - elif quantize == "bitsandbytes-nf4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="nf4", - ) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - if not isinstance(weight, Exl2Weight): - raise NotImplementedError( - f"The passed weight is not `exl2` compatible, loader needs to be updated." - ) - - from text_generation_server.layers.gptq import ExllamaQuantLinear - - linear = ExllamaQuantLinear(weight, bias) - - elif quantize == "gptq": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlinLinear, - GPTQMarlinWeight, - ) - - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, GPTQWeight): - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. 
Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - linear = ExllamaQuantLinear(weight, bias) - else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, - ) - else: - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." - ) - - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - if not isinstance(weight, GPTQWeight): - raise NotImplementedError( - f"The passed weight is not `awq` compatible, loader needs to be updated." - ) - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) - try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear - - linear = WQLinear( - w_bit=weight.bits, - group_size=weight.groupsize, - qweight=weight.qweight, - qzeros=weight.qzeros, - scales=weight.scales, - bias=bias, - ) - except ImportError: - raise NotImplementedError( - "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Linear, - GPTQMarlin24Weight, - MarlinLinear, - MarlinWeight, - ) - - if isinstance(weight, GPTQMarlin24Weight): - linear = GPTQMarlin24Linear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, MarlinWeight): - linear = MarlinLinear(weight=weight, bias=bias) - else: - raise NotImplementedError( - f"The passed weight is not `marlin` compatible, loader needs to be updated." - ) - else: - raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") - return linear + return weight.get_linear(bias) diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a1af67a3..40271c35 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,11 +1,13 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn - -from text_generation_server.layers.gptq import GPTQParams +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader try: import marlin_kernels @@ -24,16 +26,159 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_TILE_SIZE = 16 -def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: +class MarlinWeightsLoader(WeightsLoader): + """Loader for Marlin-quantized weights.""" + + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. 
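+
+        Depending on whether the loader was built for 2:4 sparse checkpoints,
+        this returns either a `GPTQMarlin24Weight` (from the `B_24`, `B_meta`
+        and `s` tensors) or a plain `MarlinWeight` (from `B` and `s`).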
+ """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_row(self, weights: Weights, prefix: str): + if self.is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. 
+ s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: return ( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 and quantize == "gptq" - and gptq_params.quant_method == "gptq" - and gptq_params.bits in GPTQ_MARLIN_BITS - and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES - and gptq_params.sym + and quant_method == "gptq" + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + and sym ) @@ -83,7 +228,7 @@ def permute_scales(scales: torch.Tensor): @dataclass -class GPTQMarlinWeight: +class GPTQMarlinWeight(Weight): """ Repacked GPTQ Marlin weights. """ @@ -101,6 +246,12 @@ class GPTQMarlinWeight: assert self.g_idx.dtype == torch.int32 assert self.perm.dtype == torch.int32 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlinLinear( + weight=self, + bias=bias, + ) + def repack_gptq_for_marlin( *, @@ -258,6 +409,12 @@ class GPTQMarlin24Weight: assert self.B_meta.dtype == torch.int16 assert self.s.dtype == torch.float16 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlin24Linear( + weight=self, + bias=bias, + ) + class GPTQMarlin24Linear(nn.Module): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): @@ -339,8 +496,126 @@ class GPTQMarlin24Linear(nn.Module): return C +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + qweight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + scale = scale.to(torch.float16) + qweight, scales = repack_fp8_for_marlin(qweight, scale) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + @classmethod + def from_unquant(cls, weight, bias, _dtype): + qweight, scale = fp8_quantize(weight) + return cls(qweight=qweight, scale=scale, bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): + return cls(qweight=weight, scale=scale, bias=bias) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). 
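+
+    Every four consecutive FP8 values along the leading dimension are viewed
+    as raw bytes and packed little-endian into a single int32, so the leading
+    dimension shrinks by a factor of 4; e.g. (shapes for illustration only) a
+    [4096, 11008] float8_e4m3fn tensor becomes a [1024, 11008] int32 tensor.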
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+
+    if fp8_tensor.shape[0] % 4 != 0:
+        raise ValueError(
+            f"Leading tensor dimension is not divisible by 4: {fp8_tensor.shape[0]}"
+        )
+
+    # Reshape to prepare for packing
+    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+    # Convert fp8 to uint8 (byte) representation
+    byte_tensor = reshaped.view(torch.uint8)
+
+    # Pack 4 uint8 values into one int32
+    packed = torch.zeros(
+        fp8_tensor.shape[0] // 4,
+        fp8_tensor.shape[1],
+        dtype=torch.int32,
+        device=fp8_tensor.device,
+    )
+
+    for i in range(4):
+        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
+
+    return packed
+
+
+def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
+    """
+    Repack FP8 tensor for GPTQ-Marlin.
+    """
+
+    out_features, in_features = weight.shape
+
+    # Torch linear layer weights have shape [out_features, in_features],
+    # GPTQ-packed weights use [in_features/pack_factor, out_features],
+    # so transpose before packing.
+    qweight = pack_fp8_as_int32(weight.t())
+
+    perm = torch.empty(0, dtype=torch.int, device=qweight.device)
+    repacked = marlin_kernels.gptq_marlin_repack(
+        qweight, perm, in_features, out_features, 8
+    )
+
+    scales = scale.reshape(1, 1).repeat(1, out_features)
+    scales = permute_scales(scales)
+
+    return repacked, scales
+
+
 @dataclass
-class MarlinWeight:
+class MarlinWeight(Weight):
     """
     Marlin weights.

@@ -356,6 +631,9 @@ class MarlinWeight:
         assert self.B.dtype == torch.int32
         assert self.s.dtype == torch.float16

+    def get_linear(self, bias: torch.Tensor):
+        return MarlinLinear(weight=self, bias=bias)
+

 class MarlinLinear(nn.Module):
     def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 87a61e82..db78ee1c 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -1,6 +1,7 @@
 import os
 import torch
 from torch import nn
+from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM

@@ -97,18 +98,22 @@ class PositionRotaryEmbedding(nn.Module):
             )
         elif rope_scaling["type"] == "yarn":
             scaling_factor = rope_scaling["factor"]
+            mscale = rope_scaling.get("mscale", 1.0)
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
             return YarnPositionRotaryEmbedding(
                 dim=2 * inv_freq.shape[0],
                 max_position_embeddings=rope_scaling[
                     "original_max_position_embeddings"
                 ],
-                base=10000.0,
+                base=base,
                 device=inv_freq.device,
                 scaling_factor=scaling_factor,
                 extrapolation_factor=1,
                 attn_factor=1,
                 beta_fast=32,
                 beta_slow=1,
+                mscale=mscale,
+                mscale_all_dim=mscale_all_dim,
             )
         elif rope_scaling["type"] in ["su", "longrope"]:
             short_factor = torch.tensor(
@@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module):
                 scaling_factor=scaling_factor,
             )
         elif rope_scaling["type"] == "yarn":
+            mscale = rope_scaling.get("mscale", 1.0)
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
             return YarnPositionRotaryEmbedding(
                 dim=2 * inv_freq.shape[0],
                 max_position_embeddings=rope_scaling[
@@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module):
                 attn_factor=1,
                 beta_fast=32,
                 beta_slow=1,
+                mscale=mscale,
+                mscale_all_dim=mscale_all_dim,
             )
         else:
             raise NotImplementedError(
@@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim):
     return ramp_func


-def get_mscale(scale=1):
+def get_mscale(scale: float = 1.0, mscale: float = 1.0):
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * 
math.log(scale) + 1.0 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): @@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): attn_factor, beta_fast, beta_slow, + mscale: float, + mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) @@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor + get_mscale(self.scaling_factor, mscale) + / get_mscale(self.scaling_factor, mscale_all_dim) + * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): @@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): - if seqlen > self.max_position_embeddings: + if seqlen > self.max_position_embeddings or True: inv_freq_extrapolation = _create_inv_freq( self.dim, self.base, self.inv_freq.device ) @@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.base, self.max_position_embeddings, ) + inv_freq_mask = ( 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation @@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): ) self.inv_freq = inv_freq - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 038de258..9dddb8ae 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer): weight = weights.get_tensor(f"{prefix}.weight") except: # ...otherwise they are quantized. 
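+                # The WeightsLoader configured on `weights` (set up via
+                # utils.quantization.get_loader) now decides how the quantized
+                # tensors at this prefix are materialized, so no `quantize`
+                # argument is needed here.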
- weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 elif weights.process_group.size() > 1: try: @@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer): quantize = config.quantize return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), + get_linear(weight, bias=None), process_group=weights.process_group, should_gather=should_gather, ) @@ -129,14 +129,12 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_gate_up(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_gate_up( - prefix, quantize=config.quantize - ) + weight = weights.get_weights_col_packed_gate_up(prefix) if bias: raise NotImplementedError("packed_gate_up only implemented without bias") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer): """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv( prefix, - quantize=config.quantize, num_heads=num_heads, num_key_value_heads=num_key_value_heads, ) @@ -160,17 +157,17 @@ class TensorParallelColumnLinear(SuperLayer): raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -178,20 +175,18 @@ class TensorParallelColumnLinear(SuperLayer): if config.quantize == "exl2": linears = [] for prefix in prefixes: - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None - linears.append(get_linear(weight, b, config.quantize)) + linears.append(get_linear(weight, b)) linear = LayerConcat(linears) else: - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) + weight = weights.get_multi_weights_col(prefixes, dim=dim) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -210,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer): else: bias = None return cls( - get_linear(weight, bias, config.quantize), + get_linear(weight, bias), process_group=weights.process_group, ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 58131a3a..a43cdfed 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py 
@@ -34,6 +34,7 @@ from text_generation_server.models.custom_modeling.t5_modeling import ( ) from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_master # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -47,9 +48,7 @@ torch.set_grad_enabled(False) __all__ = [ "Model", - "BLOOMSharded", "CausalLM", - "GalacticaSharded", "Seq2SeqLM", "get_model", ] @@ -61,6 +60,10 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) @@ -121,7 +124,7 @@ try: ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False @@ -133,7 +136,7 @@ MAMBA_AVAILABLE = True try: from text_generation_server.models.mamba import Mamba except ImportError as e: - logger.warning(f"Could not import Mamba: {e}") + log_master(logger.warning, f"Could not import Mamba: {e}") MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: @@ -141,6 +144,11 @@ if MAMBA_AVAILABLE: class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } IDEFICS2 = { "type": "idefics2", "name": "Idefics 2", @@ -302,6 +310,12 @@ def get_model( if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 + elif quantize == "fp8": + from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + + if FBGEMM_MM_AVAILABLE: + # fbgemm kernels are fp8xfp8->bf16 + dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype. @@ -424,7 +438,9 @@ def get_model( speculate = get_speculate() if speculate > 0: - logger.info(f"Using speculation {method} with {speculate} input ids.") + log_master( + logger.info, f"Using speculation {method} with {speculate} input ids." + ) if model_type is None: # TODO: fix how we determine model type for Mamba @@ -439,10 +455,10 @@ def get_model( if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) if method in {"gptq", "awq", "exl2"}: - logger.info(f"Auto selecting quantization method {method}") + log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method else: - logger.info(f"Unknown quantization method {method}") + log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( @@ -459,7 +475,40 @@ def get_model( f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." 
) - if model_type == MAMBA: + if model_type == DEEPSEEK_V2: + if FLASH_ATTENTION: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == MAMBA: return Mamba( model_id, revision, @@ -551,7 +600,7 @@ def get_model( ) except RuntimeError as e: # Lots of legacy models with various weight names. - logger.warning(f"Couldn't load flash gpt2 variant: {e}") + log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") return CausalLM.fallback( model_id, revision, @@ -573,6 +622,10 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + return FlashCausalLM( model_id=model_id, model_class=FlashGPTNeoXForCausalLM, @@ -582,6 +635,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, ) elif sharded: return CausalLM( @@ -797,6 +851,10 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 868a3cc0..87b9969a 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -20,6 +20,7 @@ from text_generation_server.utils import ( from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -491,7 +492,7 @@ class CausalLMBatch(Batch): @dataclass -class CausalLMBatchKeysLast(Batch): +class CausalLMBatchKeysLast(CausalLMBatch): keys_head_dim_last: bool = False @@ -543,15 +544,25 @@ class CausalLM(Model): config.quantize = quantize config.speculator = speculator if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = config.pad_token_id + if config.pad_token_id is not None: + tokenizer.pad_token_id = config.pad_token_id + elif config.eos_token_id is not None: + tokenizer.pad_token_id = config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, 
process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index f993fe72..c7b29d13 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -187,9 +186,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class FlashCohereAttention(torch.nn.Module): @@ -260,8 +257,8 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, - input_lengths, slots, + input_lengths, max_s, ): qkv = self.query_key_value(hidden_states) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 41aa5859..7426fc55 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig): + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + def __init__( self, d_model: int = 2048, @@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig): **kwargs, ) + @property + def num_key_value_heads(self): + # We can't use the attribute map, since this the number of KV + # heads is not top-level. + return self.attn_config.kv_n_heads + def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x @@ -235,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls): if cls == TensorParallelRowLinear: expert_slice = expert_slice.t().contiguous() - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear, weights.process_group)) else: - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear)) return experts diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py new file mode 100644 index 00000000..f5b2ba0e --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -0,0 +1,980 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention.common import Seqlen +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weights +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + + +class DeepseekV2Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = 
rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def _load_experts(config, prefix: str, mat: str, weights: Weights): + if config.quantize is not None: + raise NotImplementedError( + "Deepseek V2 does not support weight quantization yet." + ) + + assert mat in ["gate_proj", "up_proj", "down_proj"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.moe_intermediate_size % world_size == 0 + ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.moe_intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.n_routed_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.n_routed_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "down_proj": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class DeepseekV2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + 
prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. 
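+        # With the default config this pads the 128-dim value heads up to the
+        # 192-dim query/key heads (qk_nope_head_dim + qk_rope_head_dim); the
+        # numbers are illustrative, `head_pad_size` is simply the larger of
+        # the two sizes. The zero padding is sliced off again after attention
+        # below.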
+ query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + key, + value, + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV2MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. 
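+        # Stored so that forward() below can skip the fused ROCm LLMM_Silu
+        # path whenever the weights are quantized.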
+ self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config: DeepseekV2Config, weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + gate_proj = _load_experts( + config, f"{prefix}.experts", "gate_proj", weights + ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + + up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view( + self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim + ) + + self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1) + + self.down_proj = ( + _load_experts(config, f"{prefix}.experts", "down_proj", weights) + .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + out = ( + fused_experts( + x, + self.gate_up_proj, + self.down_proj, + topk_weights, + topk_ids, + inplace=True, + ) + * self.routed_scaling_factor + ) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DenseMoE(nn.Module): + def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + # + # Seems like no one 
quantizes the gate. + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.experts = [ + DeepseekV2MLP( + f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size + ) + for i in range(self.n_routed_experts) + ] + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + # gate_logits: (sequence_length, n_experts) + router_logits = self.gate(x) + + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + + out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + def moe_infer_gpu( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ): + weights = torch.zeros( + topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device + ) + weights.scatter_(1, topk_ids, topk_weight) + + out = x.new_zeros(x.shape[0], self.hidden_dim) + for i, expert in enumerate(self.experts): + # Add expert output to out with masking + out += expert(x, reduce=False) * weights[:, i].view(-1, 1) + return out + + +class DeepseekV2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE + self.mlp = moe_cls(f"{prefix}.mlp", config, weights) + else: + self.mlp = DeepseekV2MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) 
+ + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV2Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits + + +# Functions below are from vLLM: +# +# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397 +# +# Remove after we have synced our version with upstream. 
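+#
+# grouped_topk below implements the grouped routing used by Deepseek V2: expert scores are reduced per expert group, only the `topk_group` best groups are kept for each token, and the top-k experts (and their weights, optionally renormalized) are selected from those groups.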
+ + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. 
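+    # Shapes as used below: hidden_states [M, K], w1 [E, N, K] (fused gate+up), w2 [E, K, N // 2], topk_weights/topk_ids [M, top_k].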
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + import triton.language as tl + from vllm import _custom_ops as ops + from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs, + invoke_fused_moe_kernel, + moe_align_block_size, + ) + + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index beff08b3..8526d515 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class Gemma2Config(PretrainedConfig): @@ -141,24 +142,21 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = 
weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemma2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 14b62b00..dfe6510c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class GemmaConfig(PretrainedConfig): @@ -141,24 +142,21 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemmaAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d5dc25cf..a55a4af3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights): # Weights weight = weights.get_weights_col_packed_qkv( f"{prefix}.c_attn", - config.quantize, config.num_attention_heads, config.num_attention_heads, ) @@ -83,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights): bias = torch.cat(tensors, dim=0) bias = bias.to(device=weights.device) - return 
TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def _load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -130,14 +129,14 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads): 3 * num_heads * head_size ], f"{weight.shape} != {[3 * num_heads * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): """load_row, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) else: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T @@ -148,16 +147,14 @@ def load_row(config, prefix: str, weights, bias: bool): bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) def load_col(config, prefix: str, weights, bias: bool): """load_col, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=1 - ) + weight = weights.get_multi_weights_col([prefix], dim=1) else: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T @@ -166,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool): else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) class FlashGPT2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 78832341..f7980d2d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from contextlib import contextmanager from typing import List, Optional, Tuple import torch @@ -25,7 +26,6 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -33,7 +33,6 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -46,6 +45,11 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import ( + UnquantizedWeight, + Weights, +) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader if SYSTEM == "rocm": try: @@ -105,6 +109,19 @@ def load_attention(config, prefix: str, weights, layer_id): ) +@contextmanager +def no_fp8(weights: Weights): + """Deactivate fp8 auto conversion for the duration of this context manager""" + weights_loader = weights.weights_loader + if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8: + weights_loader = HybridFP8UnquantLoader( + weights_loader.activation_scale_ub, to_fp8=False + ) + + with weights.use_loader(weights_loader): + yield + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -330,12 +347,15 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): def __init__(self, index, prefix, config, weights): super().__init__() - self.self_attn = FlashLlamaAttention( - index=index, - prefix=f"{prefix}.self_attn", - config=config, - weights=weights, - ) + + with no_fp8(weights): + self.self_attn = FlashLlamaAttention( + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) @@ -396,7 +416,22 @@ class FlashLlamaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.layers = nn.ModuleList( + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else f"{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + self.layers.extend( [ FlashLlamaLayer( index=layer_id, @@ -408,9 +443,26 @@ class FlashLlamaModel(torch.nn.Module): config=config, weights=weights, ) - for layer_id in range(config.num_hidden_layers) + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1) ] ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + self.norm = FastRMSNorm.load( prefix="model.norm" if not prefix else f"{prefix}.model.norm", weights=weights, @@ -470,23 +522,27 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding( - prefix=( - "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" - ), - weights=weights, - ) + with 
no_fp8(weights): + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" + if not prefix + else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) self.model = FlashLlamaModel(prefix, config, weights) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" - self.lm_head = SpeculativeHead.load( - config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, - ) + with no_fp8(weights): + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 429793ea..a1e36fc7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -150,9 +149,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) def _load_experts(config, prefix: str, mat, weights): diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 0eca181b..623b164c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -24,7 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel -from transformers.models.gpt_neox import GPTNeoXConfig +from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( @@ -45,10 +45,17 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight + + +class GPTNeoXConfig(TransformersGPTNeoXConfig): + attribute_map = { + "num_key_value_heads": "num_attention_heads", + } def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -56,7 +63,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: @@ -64,11 +71,11 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): - weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) - if isinstance(weight, torch.Tensor): + weight = 
weights.get_multi_weights_col([prefix], dim=0) + if isinstance(weight, UnquantizedWeight): # Only on non quantized versions - weight = ( - weight.view( + weight.weight = ( + weight.weight.view( num_heads, 3, head_size, @@ -81,7 +88,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 7401bc27..a1ce03b9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -101,9 +100,7 @@ def _load_gqa(config, prefix: str, weights): ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" # this is the same as llama except for Phi uses bias=True - return TensorParallelColumnLinear( - get_linear(weight, bias=True, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=True)) class FlashPhiAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d12ed567..d7cad480 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -23,7 +23,7 @@ from text_generation_server.layers.attention import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.parallel_attn: return linear else: @@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig): attribute_map = { "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", + "num_key_value_heads": "n_head_kv", } def __init__( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 21a22046..2b939a10 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -81,11 +82,13 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - gptq_params = weights._get_gptq_params() - if gptq_params.quant_method == "gptq": + loader = 
weights.weights_loader + assert isinstance(loader, GPTQWeightsLoader) + loader._get_gptq_params(weights) + if loader.quant_method == "gptq": g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - elif gptq_params.quant_method == "awq": + elif loader.quant_method == "awq": g_idx = None from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, @@ -100,8 +103,9 @@ def _load_multi_mqa_gptq( qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, + bits=loader.bits, + groupsize=loader.groupsize, + use_awq_kernel=loader.quantize == "awq", use_exllama=HAS_EXLLAMA, ) @@ -118,7 +122,7 @@ def _load_multi_mqa_gptq( bias = torch.cat([q_tensor, kv_tensor], dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) else: raise NotImplementedError("Gptq loading with santacoder is not implemented") @@ -190,29 +194,27 @@ def _load_multi_mqa( assert list(bias.shape) == [ (num_heads + 2) * head_size ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_col(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T else: - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=0 - ) + weight = weights.get_multi_weights_col([prefix], dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T else: - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -220,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 2b346283..cfa891d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight class Starcoder2Config(PretrainedConfig): @@ -126,20 +127,19 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = 
weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.use_bias: w = [ @@ -150,9 +150,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class Starcoder2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index a83bc1c6..735c3899 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -34,6 +34,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelRowLinear, ) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module): class Idefics2ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - config.vision_config.quantize = config.quantize + config.vision_config.quantize = None config.vision_config.speculator = config.speculator config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator @@ -695,16 +696,24 @@ class Idefics2ForConditionalGeneration(nn.Module): name="text_model", ) self.dtype = weights.dtype - self.vision_model = Idefics2VisionTransformer( - prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model", - config=vision_config, - weights=weights, - ) - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) + + # The vision and connector models are not quantized. 
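+        # use_loader temporarily swaps in an unquantized weights loader so these submodules are loaded as plain tensors.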
+ with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): + self.vision_model = Idefics2VisionTransformer( + prefix=( + f"{prefix}.model.vision_model" if prefix else "model.vision_model" + ), + config=vision_config, + weights=weights, + ) + + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) + self.config = config self.image_seq_len = config.perceiver_config.resampler_n_latents self.image_token_id = config.image_token_id diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index fb09a8f1..04c547b2 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias): bias = bias.to(device=weights.device) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return TensorParallelColumnLinear(linear) @@ -337,17 +337,17 @@ class MultiheadAttention(nn.Module): weights, ): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop if self.n_heads % weights.process_group.size() != 0: raise ValueError( @@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module): def __init__(self, config, prefix, weights): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) self.Wqkv = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias @@ -614,9 +614,9 @@ class MPTBlock(nn.Module): def __init__(self, config, prefix, weights): super().__init__() self.prefix = prefix - if config.attn_config["attn_type"] != "multihead_attention": + if config.attn_config.attn_type != "multihead_attention": raise NotImplementedError( - f"""Not implemented attn {config.attn_config["attn_type"]}""" + f"""Not implemented attn 
{config.attn_config.attn_type}""" ) resid_pdrop = config.resid_pdrop if config.no_bias: @@ -789,11 +789,11 @@ class MPTModel(MPTPreTrainedModel): self.world_size = weights.process_group.size() self.rank = weights.process_group.rank() self.n_heads = config.n_heads - self.attn_impl = config.attn_config["attn_impl"] - self.prefix_lm = config.attn_config["prefix_lm"] - self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] - self.alibi = config.attn_config["alibi"] - self.alibi_bias_max = config.attn_config["alibi_bias_max"] + self.attn_impl = config.attn_config.attn_impl + self.prefix_lm = config.attn_config.prefix_lm + self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id + self.alibi = config.attn_config.alibi + self.alibi_bias_max = config.attn_config.alibi_bias_max if config.init_device == "mixed": if dist.get_local_rank() == 0: config.init_device = "cpu" diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e66011a1..cfffafa1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -23,14 +23,13 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model +from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - hub, ) from text_generation_server.models.types import ( Batch, @@ -50,6 +49,7 @@ from text_generation_server.models.globals import ( from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( @@ -838,7 +838,9 @@ class FlashCausalLM(Model): default_dtype=torch.float16, aliases=None, # Used for Santacoder override of config - num_kv_heads=None, + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, skip_special_tokens: bool = True, ): self.process_group, rank, world_size = initialize_torch_distributed() @@ -881,12 +883,16 @@ class FlashCausalLM(Model): torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader(quantize, model_id, revision) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device, dtype, process_group=self.process_group, aliases=aliases + filenames, + device, + dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) @@ -905,15 +911,23 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - # Order is important here. 
- for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: - num_kv_heads = getattr(config, "num_attention_heads", None) - if num_kv_heads is not None: - break + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround if num_kv_heads is None: - raise ValueError("Cannot get the number of key/value heads") - self.num_kv_heads = num_kv_heads // self.process_group.size() - self.head_size = config.hidden_size // config.num_attention_heads + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 + + if head_size is None: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size self.cuda_graphs = {} self.kv_cache = [] @@ -1141,31 +1155,36 @@ class FlashCausalLM(Model): f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) - logger.info( - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`." + log_master( + logger.info, + f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): - logger.info( - f"The file {tunableop_filepath} already exists and will be reused." + log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", ) torch.cuda.tunable.read_file(tunableop_filepath) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.tuning_enable(False) else: - logger.info( - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp." + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. 
If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", ) if CUDA_GRAPHS: try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) # Warmup cuda graphs for bs in CUDA_GRAPHS: if self.speculate is None or self.speculate + 1 <= bs: @@ -1173,7 +1192,9 @@ class FlashCausalLM(Model): except torch.cuda.OutOfMemoryError: logger.exception(f"Decode cuda graph warmup failed") else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) return int(num_blocks * BLOCK_SIZE) @@ -1525,8 +1546,7 @@ class FlashCausalLM(Model): left = 0 if n_accepted_ids > 1: - if RANK == 0: - logger.debug(f"Speculated ids {n_accepted_ids - 1}") + log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") current_stopped = False for j in range(index, index + n_accepted_ids): diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06035ccd..ac42df30 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,15 +1,16 @@ import torch import os from loguru import logger -from typing import Dict +from typing import Dict, Optional + +from text_generation_server.utils.log import log_master MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: - logger.info("Using FLASH_DECODING") - + log_master(logger.info, "Using FLASH_DECODING") cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: @@ -26,11 +27,9 @@ else: if cuda_graphs is not None: cuda_graphs.sort(reverse=True) - CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. -global MODEL_ID MODEL_ID = None @@ -41,8 +40,7 @@ def set_model_id(model_id: str): # NOTE: eventually we should move this into the router and pass back the # index in all cases. 
-global ADAPTER_TO_INDEX -ADAPTER_TO_INDEX: Dict[str, int] = None +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index f2955bd0..0deab6ce 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -23,6 +23,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.quantization import get_loader class IDEFICSSharded(IdeficsCausalLM): @@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM): device=device, dtype=dtype, process_group=self.process_group, + weights_loader=weights_loader, ) model = IdeficsForVisionText2Text(config, weights) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9189b45c..4ed9722c 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -28,6 +28,7 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -448,8 +449,17 @@ class Mamba(Model): config.quantize = quantize config.speculator = speculator torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) model = MambaModel(config, weights) torch.distributed.barrier(group=self.process_group) super(Mamba, self).__init__( diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 09130b85..e7748bb9 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,6 +15,7 @@ from text_generation_server.utils.adapter import ( AdapterParameters, AdapterSource, ) +from text_generation_server.utils.log import log_master from loguru import logger @@ -204,8 +205,9 @@ class Model(ABC): f"order to use the dynamic adapter loading feature." 
) - logger.info( - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + log_master( + logger.info, + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", ) weight_names = tuple([v[0] for v in self.target_to_layer.values()]) ( @@ -240,8 +242,9 @@ class Model(ABC): layer_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: - logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + log_master( + logger.warning, + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index dbaf1253..fa8b5025 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -18,6 +18,7 @@ from text_generation_server.utils import ( Weights, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -586,6 +587,9 @@ class Seq2SeqLM(Model): ) tokenizer.bos_token_id = config.decoder_start_token_id + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -594,6 +598,7 @@ class Seq2SeqLM(Model): dtype=dtype, process_group=self.process_group, aliases=aliases, + weights_loader=weights_loader, ) if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index ace48805..308d5a3d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,4 +1,3 @@ -from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -13,6 +12,7 @@ from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils.log import log_master from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info( - f"Found {num_features} features in image of resolution {height}x{width}" + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", ) return "<image>" * num_features @@ -261,7 +262,12 @@ class VlmCausalLM(FlashCausalLM): **processor_kwargs, ) self.batch_class = batch_class - super().__init__(model_id=model_id, **kwargs) + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **kwargs, + ) @property def batch_type(self) -> Type[VlmCausalLMBatch]: diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 36d63e86..82aeba6c 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -56,7 +56,7 @@ def 
initialize_torch_distributed():
         backend = "nccl"
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
-        options._timeout = timedelta(seconds=60)
+        options._timeout = timedelta(seconds=120)
     else:
         backend = "gloo"
         options = None
@@ -76,7 +76,7 @@ def initialize_torch_distributed():
                 backend="ccl",
                 world_size=WORLD_SIZE,
                 rank=RANK,
-                timeout=timedelta(seconds=60),
+                timeout=timedelta(seconds=120),
                 pg_options=options,
             )
         else:
@@ -84,7 +84,7 @@ def initialize_torch_distributed():
                 backend=backend,
                 world_size=WORLD_SIZE,
                 rank=RANK,
-                timeout=timedelta(seconds=60),
+                timeout=timedelta(seconds=120),
                 pg_options=options,
             )
     else:
diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py
index b1456f1e..4385c71e 100644
--- a/server/text_generation_server/utils/log.py
+++ b/server/text_generation_server/utils/log.py
@@ -1,6 +1,15 @@
 from functools import lru_cache
+from text_generation_server.utils.dist import RANK
 
 
 @lru_cache(10)
-def log_once(log, msg: str):
-    log(msg)
+def log_once(log, msg: str, master=True):
+    if master:
+        log_master(log, msg)
+    else:
+        log(msg)
+
+
+def log_master(log, msg: str):
+    if RANK == 0:
+        log(msg)
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
new file mode 100644
index 00000000..c3c038fe
--- /dev/null
+++ b/server/text_generation_server/utils/quantization.py
@@ -0,0 +1,173 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+from huggingface_hub import hf_hub_download
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    WeightsLoader,
+)
+
+
+# TODO: Split this config to have a single config type per quant method
+@dataclass
+class _QuantizerConfig:
+    bits: int
+    checkpoint_format: Optional[str]
+    desc_act: bool
+    groupsize: int
+    quant_method: str
+    sym: bool
+
+
+@dataclass
+class _FP8QuantizerConfig:
+    activation_scale_ub: float
+
+
+# We should probably do this with Pydantic JSON deserialization,
+# but for now we'll stay close to the old _set_gptq_params.
+def _get_quantizer_config(model_id, revision): + bits = 4 + groupsize = -1 + quant_method = "gptq" + checkpoint_format = None + sym = True + desc_act = False + + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] + # Order is important here, desc_act is missing on some real models + quant_method = data["quantization_config"]["quant_method"] + checkpoint_format = data["quantization_config"].get("checkpoint_format") + sym = data["quantization_config"]["sym"] + desc_act = data["quantization_config"]["desc_act"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["bits"] + groupsize = data["group_size"] + sym = data["sym"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + filename = "quant_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["w_bit"] + groupsize = data["q_group_size"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + pass + + return _QuantizerConfig( + bits=bits, + groupsize=groupsize, + quant_method=quant_method, + checkpoint_format=checkpoint_format, + sym=sym, + desc_act=desc_act, + ) + + +def get_loader( + quantize: Optional[str], model_id: str, revision: Optional[str] +) -> WeightsLoader: + quantizer_config = _get_quantizer_config(model_id, revision) + if quantize in {"awq", "gptq"}: + from text_generation_server.layers.gptq import GPTQWeightsLoader + + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." 
+        )
+
+        return GPTQWeightsLoader(
+            bits=quantizer_config.bits,
+            desc_act=quantizer_config.desc_act,
+            groupsize=quantizer_config.groupsize,
+            quant_method=quantizer_config.quant_method,
+            quantize=quantize,
+            sym=quantizer_config.sym,
+        )
+    elif quantize == "bitsandbytes":
+        from text_generation_server.layers.bnb import BNBWeight
+
+        return DefaultWeightsLoader(BNBWeight)
+    elif quantize == "bitsandbytes-fp4":
+        from text_generation_server.layers.bnb import BNBFP4Weight
+
+        return DefaultWeightsLoader(BNBFP4Weight)
+    elif quantize == "bitsandbytes-nf4":
+        from text_generation_server.layers.bnb import BNBNF4Weight
+
+        return DefaultWeightsLoader(BNBNF4Weight)
+    elif quantize == "eetq":
+        from text_generation_server.layers.eetq import EETQWeight
+
+        return DefaultWeightsLoader(EETQWeight)
+    elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2WeightsLoader
+
+        return Exl2WeightsLoader()
+    elif quantize == "marlin":
+        from text_generation_server.layers.marlin import MarlinWeightsLoader
+
+        # TODO: improve check once we have one config type per quantize value
+        if not isinstance(quantizer_config, _QuantizerConfig):
+            raise ValueError(
+                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
+            )
+
+        return MarlinWeightsLoader(
+            bits=quantizer_config.bits,
+            is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
+        )
+    elif quantize == "fp8" or quantize is None:
+        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+
+        # Since the default for the quantize config is _QuantizerConfig,
+        # we need this check to avoid an attribute error.
+        activation_scale_ub = None
+        if isinstance(quantizer_config, _FP8QuantizerConfig):
+            activation_scale_ub = quantizer_config.activation_scale_ub
+
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+    else:
+        raise ValueError(f"Unknown quantization method: {quantize}")
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 3731fd24..66bb6051 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,13 +1,140 @@
-import os
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-from safetensors import safe_open, SafetensorError
 import torch
-from loguru import logger
-from huggingface_hub import hf_hub_download
-import json
-from text_generation_server.layers.gptq import GPTQParams
-from text_generation_server.utils.log import log_once
+
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, List, Optional, Union, Type
+from safetensors import safe_open
+from dataclasses import dataclass
+
+from text_generation_server.utils.import_utils import SYSTEM
+
+
+class WeightsLoader(ABC):
+    """
+    Instances of this type implement higher-level weight loading.
+
+    At a low level, every weight is stored in the Safetensors format.
+    However, the interpretation of the weights may differ: for instance,
+    they could be packed or quantized weights. Loaders are responsible for
+    interpreting the raw tensors, sharding tensors in a manner compatible
+    with the format, etc.
+    """
+
+    @abstractmethod
+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix without applying tensor parallelism.
+        """
+        ...
+
+    @abstractmethod
+    def get_weights_col_packed(
+        self,
+        weights: "Weights",
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        """
+        Get the packed weights at the given prefix with column-splitting for
+        tensor parallelism. This method should be used when multiple different
+        weights are packed into a tensor, for instance, query/key/value
+        weights or a gate/up projection.
+
+        The `block_sizes` argument determines the proportions of the packed
+        tensors. The columns are split in equally sized blocks when
+        `block_sizes` is an `int`, or in blocks proportional to the given
+        sizes. For instance, `[2, 1, 1]` will divide an input with
+        dimensionality `1024` into `[512, 256, 256]`.
+        """
+        ...
+
+    def get_weights_col(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply column-splitting for tensor
+        parallelism.
+        """
+        return weights.get_multi_weights_col([prefix], 0)
+
+    @abstractmethod
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        """
+        Get the weights at the given prefixes, column-split them for tensor
+        parallelism, and then concatenate the weights along the given dimension.
+        """
+        ...
+
+    @abstractmethod
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        """
+        Get the weights at the given prefix and apply row-splitting for tensor
+        parallelism.
+        """
+        ...
+
+
+class Weight(ABC):
+    """Instances of this type implement unquantized, quantized, or
+    to-be-quantized weights."""
+
+    @abstractmethod
+    def get_linear(self, bias: torch.Tensor):
+        """Create a linear layer from this weight."""
+        ...
+
+
+@dataclass
+class UnquantizedWeight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        from text_generation_server.layers.linear import FastLinear, FastLinearROCm
+
+        if SYSTEM == "rocm":
+            return FastLinearROCm(self.weight, bias)
+        else:
+            return FastLinear(self.weight, bias)
+
+
+class DefaultWeightsLoader(WeightsLoader):
+    """Weight loader that loads (unquantized) Torch tensors."""
+
+    def __init__(self, weight_class: Type[UnquantizedWeight]):
+        """Create a loader. Weights will be wrapped using the given `weight_class`;
+        normally this will be `UnquantizedWeight`, but a quantizer-specific class
+        such as `Fp8Weight` can be used to quantize the weights during loading.
+        """
+        self.weight_class = weight_class
+
+    """
+    Loader that uses tensors as-is with the exception of applying sharding
+    and/or concatenation.
+ """ + + def get_weights(self, weights: "Weights", prefix: str): + return weights.get_tensor(f"{prefix}.weight") + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + + return self.weight_class( + weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ), + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + return self.weight_class(torch.cat(w, dim=dim)) + + def get_weights_row(self, weights: "Weights", prefix: str): + return self.weight_class( + weights.get_sharded(f"{prefix}.weight", dim=1), + ) class Weights: @@ -17,6 +144,7 @@ class Weights: device, dtype, process_group, + weights_loader: WeightsLoader, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, ): @@ -37,6 +165,7 @@ class Weights: self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = weights_loader self._handles = {} def _get_handle(self, filename): @@ -69,23 +198,39 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True): + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert # u4 which are disguised as int32. Exl2 uses int16 - # as well. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + # as well. FP8 uses torch.float8_e4m3fn + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -105,12 +250,16 @@ class Weights: raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert # u4 which are disguised as int32. exl2 uses int16. - if tensor.dtype not in (torch.int16, torch.int32): + # FP8 uses torch.float8_e4m3fn. 
+ if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -119,10 +268,14 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) def get_packed_sharded( - self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, ) -> torch.Tensor: """ Get a shard from a tensor that packs multiple tensors. @@ -168,308 +321,51 @@ class Weights: tensor = tensor.to(device=self.device) # Avoid casting quantizer dtypes. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) return tensor + def get_weights(self, prefix: str): + return self.weights_loader.get_weights(self, prefix) + def get_weights_col_packed_qkv( self, prefix: str, - quantize: str, num_heads: int, num_key_value_heads: int, ): return self.get_weights_col_packed( - prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + prefix, [num_heads, num_key_value_heads, num_key_value_heads] ) - def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 2) + def get_weights_col_packed_gate_up(self, prefix: str): + return self.get_weights_col_packed(prefix, 2) - def get_weights_col_packed( - self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] - ): + def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]): """ - Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor. - The columns are split in equally sized blocks when blocks is an `int`, or in blocks proportional given to the sizes. For instance `[2, 1, 1]` will divide an input with dimensionality `1024` in `[512, 256, 256]`. This is convenient for e.g. splitting QKV without knowing the storage details of quantized weights. """ - if quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) + return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes) - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized." 
- ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) + def get_weights_col(self, prefix: str): + return self.weights_loader.get_weights_col(self, prefix) - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - g_idx = self.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = self.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - if quantize == "gptq" and gptq_params.quant_method == "gptq": - g_idx = self.get_tensor(f"{prefix}.g_idx") - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=False, - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - repack_gptq_for_marlin, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - B = self.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = self.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - B = self.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - return weight - - def get_weights_col(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." 
- ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - return self.get_multi_weights_col([prefix], quantize, 0) - - def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize == "exl2": - raise ValueError("get_multi_weights_col is not supported for exl2") - elif quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - gptq_params.bits == 4 - and HAS_EXLLAMA - and quantize == "gptq" - and not gptq_params.desc_act - ) - - if quantize == "gptq" and gptq_params.quant_method == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
- ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - else: - w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=dim) - - return weight + def get_multi_weights_col(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights_col(self, prefixes, dim) def get_tensor_shard(self, var, dim): world_size = self.process_group.size() @@ -487,318 +383,22 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight + def get_weights_row(self, prefix: str): + return self.weights_loader.get_weights_row(self, prefix) - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) + @contextmanager + def use_loader(self, weights_loader: WeightsLoader): + """ + This method is a context manager that can be used to use `Weights` with + a different loader for the duration of the context. 
+ """ - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - elif quantize == "gptq": - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) - - use_exllama = True - if gptq_params.bits != 4: - use_exllama = False - - if gptq_params.desc_act: - log_once(logger.warning, "Disabling exllama because desc_act=True") - use_exllama = False - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - - if gptq_params.quant_method == "gptq": - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - elif gptq_params.quant_method == "awq": - g_idx = None - - if self.process_group.size() > 1: - if g_idx is not None: - if ( - not torch.equal( - g_idx.cpu(), - torch.tensor( - [ - i // gptq_params.groupsize - for i in range(g_idx.shape[0]) - ], - dtype=torch.int32, - ), - ) - and not (g_idx == 0).all() - ): - # Exllama implementation does not support row tensor parallelism with act-order, as - # it would require to reorder input activations that are split unto several GPUs - use_exllama = False - - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and gptq_params.groupsize != -1: - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - - if gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
- ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `awq` weight, make sure the model is already quantized" - ) - - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - g_idx = None - use_exllama = False - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = self.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = self.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. 
- s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_sharded(f"{prefix}.weight", dim=1) - return weight - - def _get_gptq_params(self) -> GPTQParams: + old_loader = self.weights_loader + self.weights_loader = weights_loader try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = False - sym = False - quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - try: - bits = self.gptq_bits - groupsize = self.gptq_groupsize - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = getattr(self, "gptq_desc_act", False) - quant_method = getattr(self, "quant_method", "gptq") - sym = getattr(self, "sym", True) - except Exception: - raise e - - return GPTQParams( - bits=bits, - checkpoint_format=checkpoint_format, - desc_act=desc_act, - groupsize=groupsize, - quant_method=quant_method, - sym=sym, - ) - - def _set_gptq_params(self, model_id, revision): - filename = "config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["quantization_config"]["bits"] - self.gptq_groupsize = data["quantization_config"]["group_size"] - # Order is important here, desc_act is missing on some real models - self.quant_method = data["quantization_config"]["quant_method"] - self.gptq_checkpoint_format = data["quantization_config"].get( - "checkpoint_format" - ) - self.gptq_sym = data["quantization_config"]["sym"] - self.gptq_desc_act = data["quantization_config"]["desc_act"] - except Exception: - filename = "quantize_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["bits"] - self.gptq_groupsize = data["group_size"] - self.gptq_sym = data["sym"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - filename = "quant_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["w_bit"] - self.gptq_groupsize = data["q_group_size"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - pass + yield + finally: + self.weights_loader = old_loader def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: diff --git a/update_doc.py b/update_doc.py index 1ff94a2c..bfa7e4e9 100644 --- a/update_doc.py +++ b/update_doc.py @@ -155,7 +155,7 @@ def check_openapi(check: bool): filename, ], capture_output=True, - ).stdout.decode() + ).stdout.decode("utf-8") os.remove(tmp_filename) if diff: @@ -164,11 +164,27 @@ def check_openapi(check: bool): "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" ) - return True else: os.rename(tmp_filename, 
filename) print("OpenAPI documentation updated.") - return True + errors = subprocess.run( + [ + "swagger-cli", + # allow for trailing whitespace since it's not significant + # and the precommit hook will remove it + "validate", + filename, + ], + capture_output=True, + ).stderr.decode("utf-8") + # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where + # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 + if not errors.startswith("Swagger schema validation failed."): + print(errors) + raise Exception( + f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + ) + return True def main():
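
Editor's note: the per-model changes above (idefics.py, mamba.py, seq2seq_lm.py, ...) all follow the same wiring pattern: resolve a WeightsLoader once via get_loader and hand it to Weights, which then delegates all quantizer-specific loading to it. A minimal sketch of that pattern, assuming the imports shown in the patch; the helper name load_model_weights and its argument list are illustrative, not part of the patch:

    from typing import Optional

    from text_generation_server.utils import Weights, weight_files
    from text_generation_server.utils.quantization import get_loader


    def load_model_weights(
        model_id: str,
        revision: Optional[str],
        quantize: Optional[str],
        device,
        dtype,
        process_group,
    ) -> Weights:
        # Resolve the quantizer-specific loader once, from the model's config files.
        weights_loader = get_loader(quantize=quantize, model_id=model_id, revision=revision)
        # Locate the safetensors shards for the model.
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        # Weights now requires a loader and delegates get_weights_col/row/... to it.
        return Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=process_group,
            weights_loader=weights_loader,
        )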
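The utils/log.py change routes messages through log_master, which only logs on rank 0 so that tensor-parallel shards do not repeat every message once per GPU. A self-contained sketch of the same idea; the environment-variable lookup stands in for the real RANK imported from utils.dist and is an assumption of this example:

    import os

    from loguru import logger

    # Stand-in for text_generation_server.utils.dist.RANK; each shard is assumed
    # to receive its rank through the RANK environment variable.
    RANK = int(os.getenv("RANK", "0"))


    def log_master(log, msg: str):
        # Every shard executes the same code path, but only the first one logs.
        if RANK == 0:
            log(msg)


    log_master(logger.info, "Loading adapter weights into model: my-adapter")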
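The new _get_quantizer_config in utils/quantization.py reads the quantization_config section of config.json (falling back to quantize_config.json and quant_config.json) and turns it into a _QuantizerConfig. A minimal illustration of that mapping, with a made-up GPTQ-style config and a local copy of the dataclass so the snippet runs on its own; only the field names are taken from the patch:

    import json
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class QuantizerConfig:  # local mirror of _QuantizerConfig, for illustration only
        bits: int
        checkpoint_format: Optional[str]
        desc_act: bool
        groupsize: int
        quant_method: str
        sym: bool


    # Typical shape of the relevant part of a GPTQ checkpoint's config.json
    # (values invented for the example).
    raw = json.loads(
        """
        {
          "quantization_config": {
            "bits": 4,
            "group_size": 128,
            "quant_method": "gptq",
            "sym": true,
            "desc_act": false
          }
        }
        """
    )

    qc = raw["quantization_config"]
    config = QuantizerConfig(
        bits=qc["bits"],
        groupsize=qc["group_size"],
        quant_method=qc["quant_method"],
        checkpoint_format=qc.get("checkpoint_format"),
        sym=qc["sym"],
        desc_act=qc["desc_act"],
    )
    print(config)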
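The block_sizes argument of get_weights_col_packed splits a packed tensor either into equal blocks (when an int is given) or proportionally to a list of sizes, as in the `[2, 1, 1]` with dimensionality 1024 example from the docstring. Below is an illustrative reimplementation of that splitting rule, not the repository's _blocks_to_block_sizes itself:

    from typing import List, Union


    def split_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
        if isinstance(blocks, int):
            # Equal blocks, e.g. a packed gate/up projection split into two halves.
            assert total_size % blocks == 0, "total size must be divisible by the block count"
            return [total_size // blocks] * blocks
        # Proportional blocks, e.g. packed Q/K/V under grouped-query attention.
        unit = total_size // sum(blocks)
        assert unit * sum(blocks) == total_size, "total size must be divisible by sum(blocks)"
        return [unit * b for b in blocks]


    print(split_block_sizes(1024, [2, 1, 1]))  # [512, 256, 256]
    print(split_block_sizes(1024, 2))          # [512, 512]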
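Finally, Weights.use_loader is a context manager that temporarily swaps the active WeightsLoader and restores the previous one when the block exits, which is useful when a particular tensor should bypass the quantized path. A hypothetical usage sketch; the weights object and the model.embed_tokens prefix are placeholders, not taken from the patch:

    from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


    def load_embedding_unquantized(weights, prefix: str = "model.embed_tokens"):
        # Whatever loader `weights` was constructed with (GPTQ, Marlin, ...) is
        # put back automatically when the `with` block exits.
        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
            return weights.get_weights_col(prefix)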