diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index 8af0b95d..e10b232c 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -30,6 +30,10 @@ jobs: id: install-router run: cargo install --path router/ + - uses: actions/setup-node@v4 + with: + node-version: 22 + - name: Set up Python uses: actions/setup-python@v2 with: @@ -37,4 +41,5 @@ jobs: - name: Check that documentation is up-to-date run: | + npm install -g swagger-cli python update_doc.py --check diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index b0049701..cd9f19ba 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -11,6 +11,11 @@ on: # - rocm # - intel required: true + release-tests: + description: "Run release integration tests" + required: true + default: false + type: boolean jobs: build-and-push: @@ -23,7 +28,7 @@ jobs: group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true # TODO see with @Glegendre to get CPU runner here instead - runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] + runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] permissions: contents: write packages: write @@ -131,8 +136,8 @@ jobs: DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} - cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min - cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min + cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min + cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min - name: Final id: final run: | @@ -148,7 +153,7 @@ jobs: runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: - PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }} + PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 754c4850..d62297e4 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -20,7 +20,14 @@ on: - "Dockerfile_amd" - "Dockerfile_intel" branches: - - 'main' + - "main" + workflow_dispatch: + inputs: + release-tests: + description: "Run release integration tests" + required: true + default: false + type: boolean jobs: build: @@ -33,4 +40,6 @@ jobs: uses: ./.github/workflows/build.yaml # calls the one above ^ with: hardware: ${{ matrix.hardware }} + # https://github.com/actions/runner/issues/2206 + release-tests: ${{ 
inputs.release-tests == true }} secrets: inherit diff --git a/Cargo.lock b/Cargo.lock index c80d864a..7660f1f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1935,17 +1935,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "metrics" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" -dependencies = [ - "ahash", - "metrics-macros", - "portable-atomic", -] - [[package]] name = "metrics" version = "0.23.0" @@ -1969,7 +1958,7 @@ dependencies = [ "hyper-util", "indexmap 2.2.6", "ipnet", - "metrics 0.23.0", + "metrics", "metrics-util", "quanta", "thiserror", @@ -1977,17 +1966,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "metrics-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", -] - [[package]] name = "metrics-util" version = "0.17.0" @@ -1997,7 +1975,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "metrics 0.23.0", + "metrics", "num_cpus", "quanta", "sketches-ddsketch", @@ -3835,7 +3813,7 @@ dependencies = [ "init-tracing-opentelemetry", "itertools 0.10.5", "jsonschema", - "metrics 0.21.1", + "metrics", "metrics-exporter-prometheus", "minijinja", "minijinja-contrib", diff --git a/Dockerfile b/Dockerfile index d4772b4a..3f2e8ef0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,9 @@ RUN cargo build --profile release-opt # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install +# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 ARG PYTORCH_VERSION=2.3.0 + ARG PYTHON_VERSION=3.10 # Keep in sync with `server/pyproject.toml ARG CUDA_VERSION=12.1 @@ -241,7 +243,10 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements_cuda.txt && \ - pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir + pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \ + pip install nvidia-nccl-cu12==2.22.3 + +ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 # Deps before the binaries # The binaries change on every build given we burn the SHA into them diff --git a/README.md b/README.md index 4c1c1e29..4287c119 100644 --- a/README.md +++ b/README.md @@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint. 
## Table of contents -- [Get Started](#get-started) - - [API Documentation](#api-documentation) - - [Using a private or gated model](#using-a-private-or-gated-model) - - [A note on Shared Memory](#a-note-on-shared-memory-shm) - - [Distributed Tracing](#distributed-tracing) - - [Local Install](#local-install) - - [CUDA Kernels](#cuda-kernels) -- [Optimized architectures](#optimized-architectures) -- [Run Mistral](#run-a-model) - - [Run](#run) - - [Quantization](#quantization) -- [Develop](#develop) -- [Testing](#testing) + - [Get Started](#get-started) + - [Docker](#docker) + - [API documentation](#api-documentation) + - [Using a private or gated model](#using-a-private-or-gated-model) + - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm) + - [Distributed Tracing](#distributed-tracing) + - [Architecture](#architecture) + - [Local install](#local-install) + - [Optimized architectures](#optimized-architectures) + - [Run locally](#run-locally) + - [Run](#run) + - [Quantization](#quantization) + - [Develop](#develop) + - [Testing](#testing) Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as: diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index a56edaca..e36dd470 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): role: str content: Optional[str] = None - tool_calls: Optional[ChoiceDeltaToolCall] + tool_calls: Optional[ChoiceDeltaToolCall] = None class Choice(BaseModel): diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index c9b4efd9..119c5662 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -11,6 +11,8 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_intel + title: Using TGI with Intel GPUs - local: installation title: Installation from source - local: supported_models diff --git a/docs/source/architecture.md b/docs/source/architecture.md index a8418817..28c84f62 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ. - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi). 
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md new file mode 100644 index 00000000..f9fda863 --- /dev/null +++ b/docs/source/installation_intel.md @@ -0,0 +1,19 @@ +# Using TGI with Intel GPUs + +TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html) and [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html); the recommended usage is through Docker. + + +On a server powered by Intel GPUs, TGI can be launched with the following command: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --rm --privileged --cap-add=sys_nice \ + --device=/dev/dri \ + --ipc=host --shm-size 1g --net host -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:latest-intel \ + --model-id $model --cuda-graphs 0 +``` + +The launched TGI server can then be queried from clients; make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index c546bc03..f056baad 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ ### Supported hardware -TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. +TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
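Not part of this diff, but as a quick illustration of the "queried from clients" step in the new Intel GPU doc above: a minimal Rust sketch that POSTs to the server's `/generate` endpoint. It assumes the `reqwest` crate (with the `blocking` and `json` features) plus `serde_json`, and that the container launched above is reachable on the host's port 80 via `--net host`; the prompt and `max_new_tokens` value are arbitrary.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Request body in the shape the /generate endpoint expects.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": { "max_new_tokens": 20 }
    });

    // With `--net host`, the container's default port 80 is exposed directly.
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:80/generate")
        .json(&body)
        .send()?
        .text()?;

    // Print the raw JSON response from the server.
    println!("{response}");
    Ok(())
}
```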
## Consuming TGI diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 1eeed39f..2bdd00de 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) +- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [Gemma2](https://huggingface.co/google/gemma2-9b) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json index 7797cc6c..0f99d259 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -5,85 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.7890625, "text": "Test" }, { - "id": 2009, - "logprob": -9.625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3359375, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8779297, + "id": 262, + "logprob": -1.6230469, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2744141, + "id": 3270, + "logprob": -2.046875, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1425781, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.9238281, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6933594, + "id": 13204, + "logprob": -0.076660156, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4648438, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15600586, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.8027344, + "id": 3019, + "logprob": -0.10821533, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.23022461, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.0069885254, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.02218628, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json index fa2fd4a2..4152b5b3 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -5,85 +5,80 @@ "generated_tokens": 10, "prefill": [ { - 
"id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.6015625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": 0, "tokens": [ { - "id": 29899, - "logprob": -1.5625, + "id": 13, + "logprob": -2.2539062, "special": false, - "text": "-" + "text": "." }, { - "id": 1454, - "logprob": -0.20410156, + "id": 578, + "logprob": -0.15563965, "special": false, - "text": "for" + "text": " The" }, { - "id": 29899, + "id": 3622, + "logprob": -0.8203125, + "special": false, + "text": " server" + }, + { + "id": 706, "logprob": 0.0, "special": false, - "text": "-" + "text": " has" }, { - "id": 9342, + "id": 539, "logprob": 0.0, "special": false, - "text": "comment" + "text": " not" }, { - "id": 29901, + "id": 3686, "logprob": 0.0, "special": false, - "text": ":" + "text": " yet" }, { - "id": 396, - "logprob": -0.27685547, - "special": false, - "text": " #" - }, - { - "id": 29906, - "logprob": -0.4970703, - "special": false, - "text": "2" - }, - { - "id": 29900, - "logprob": -0.80615234, - "special": false, - "text": "0" - }, - { - "id": 29896, + "id": 3288, "logprob": 0.0, "special": false, - "text": "1" + "text": " sent" }, { - "id": 29955, - "logprob": -1.0751953, + "id": 904, + "logprob": 0.0, "special": false, - "text": "7" + "text": " any" + }, + { + "id": 828, + "logprob": 0.0, + "special": false, + "text": " data" + }, + { + "id": 382, + "logprob": -1.5517578, + "special": false, + "text": ".\n\n" } ], "top_tokens": null }, - "generated_text": "Test request-for-comment: #2017" + "generated_text": "Test request. The server has not yet sent any data.\n\n" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json index 594b7351..75e90303 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -6,87 +6,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.828125, "text": "Test" }, { - "id": 2009, - "logprob": -9.609375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3300781, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8740234, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2646484, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.7158203, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4667969, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15344238, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.81591797, + "id": 3019, + 
"logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22973633, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.007045746, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021957397, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -95,87 +90,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.59375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3378906, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8779297, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2636719, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6992188, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4589844, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15344238, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.79052734, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22937012, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.007041931, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.022140503, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -184,87 +174,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.609375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3261719, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8730469, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2587891, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6894531, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.46875, + 
"id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.1541748, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.80322266, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22912598, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.0070495605, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021606445, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -273,86 +258,81 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.6015625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3320312, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.875, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2646484, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6884766, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4589844, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15185547, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.79833984, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22827148, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.006996155, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021560669, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" } ] diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json index bf2dc5a1..44ccea71 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -1,130 +1,124 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 19, "prefill": [], "seed": null, "tokens": [ { "id": 415, - "logprob": -0.039886475, + "logprob": -0.03665161, "special": false, "text": " The" }, { "id": 12072, - "logprob": -0.1430664, + "logprob": -0.13549805, "special": false, "text": " cow" }, { "id": 349, - 
"logprob": -0.056488037, + "logprob": -0.05819702, "special": false, "text": " is" }, { "id": 6328, - "logprob": -0.6855469, + "logprob": -0.6826172, "special": false, "text": " standing" }, { "id": 356, - "logprob": -0.1685791, + "logprob": -0.1607666, "special": false, "text": " on" }, { "id": 272, - "logprob": -0.50097656, + "logprob": -0.5073242, "special": false, "text": " the" }, { "id": 10305, - "logprob": -0.017303467, + "logprob": -0.016418457, "special": false, "text": " beach" }, { "id": 304, - "logprob": -1.3564453, + "logprob": -1.3916016, "special": false, "text": " and" }, { "id": 272, - "logprob": -0.017868042, + "logprob": -0.020217896, "special": false, "text": " the" }, { "id": 13088, - "logprob": -0.0027103424, + "logprob": -0.0028133392, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.003156662, + "logprob": -0.003145218, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.37304688, + "logprob": -0.37060547, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.034576416, + "logprob": -0.034851074, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.29418945, + "logprob": -0.2878418, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.042877197, + "logprob": -0.046051025, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.00028443336, + "logprob": -0.00028848648, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.023223877, + "logprob": -0.025772095, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.018157959, + "logprob": -0.018127441, "special": false, "text": "." }, { "id": 32002, - "logprob": -0.00018393993, + "logprob": -0.00019824505, "special": true, "text": "" - }, - { - "id": 2, - "logprob": -1.1920929e-07, - "special": true, - "text": "" } ], "top_tokens": null diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json new file mode 100644 index 00000000..69c1f47d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.8359375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6171875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3417969, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8730469, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2626953, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7060547, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4482422, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15246582, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.796875, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22766113, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007045746, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021759033, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" +} diff --git 
a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json new file mode 100644 index 00000000..9b5ee9ee --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.7890625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 29899, + "logprob": -1.4980469, + "special": false, + "text": "-" + }, + { + "id": 1454, + "logprob": -0.19433594, + "special": false, + "text": "for" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 9342, + "logprob": 0.0, + "special": false, + "text": "comment" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 396, + "logprob": -0.27392578, + "special": false, + "text": " #" + }, + { + "id": 29906, + "logprob": -0.49389648, + "special": false, + "text": "2" + }, + { + "id": 29900, + "logprob": -0.81103516, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": 0.0, + "special": false, + "text": "1" + }, + { + "id": 29955, + "logprob": -1.0800781, + "special": false, + "text": "7" + } + ], + "top_tokens": null + }, + "generated_text": "Test request-for-comment: #2017" +} diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json new file mode 100644 index 00000000..df975635 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.8828125, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.5859375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3359375, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8623047, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2451172, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6923828, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4492188, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15197754, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.8022461, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22583008, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007095337, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021652222, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.796875, 
+ "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3476562, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8789062, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2734375, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.703125, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4677734, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15454102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.7973633, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.23278809, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.006980896, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.022033691, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.9296875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.5703125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3203125, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8486328, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2480469, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7060547, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4511719, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1529541, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.81396484, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22180176, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007133484, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021835327, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6171875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3261719, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8691406, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2597656, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7070312, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4550781, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1538086, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.79345703, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22924805, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0070266724, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021942139, + "special": false, + "text": "\n" + } + 
], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + } +] diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py index 135f4b05..94a48e49 100644 --- a/integration-tests/models/test_flash_llama_gptq.py +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def flash_llama_gptq_handle(launcher): - with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle: + with launcher( + "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq" + ) as handle: yield handle diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index 9aaf6d8a..c5f48da3 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot) response.generated_text == " The cow is standing on the beach and the chicken is sitting on a pile of money." ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 19 assert response == response_snapshot diff --git a/router/Cargo.toml b/router/Cargo.toml index 1a7611bf..342a1c48 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -24,7 +24,7 @@ futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.21.1" +metrics = "0.23.0" metrics-exporter-prometheus = { version = "0.15.1", features = [] } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 49282eb9..f3b10450 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -91,14 +91,14 @@ impl Infer { .limit_concurrent_requests .try_acquire_owned() .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); + metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1); tracing::error!("{err}"); err })?; // Validate request let valid_request = self.validation.validate(request).await.map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); err })?; @@ -140,7 +140,7 @@ impl Infer { .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? 
.apply(messages, grammar_with_prompt) .map_err(|e| { - metrics::increment_counter!("tgi_request_failure", "err" => "template"); + metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); e }) @@ -214,7 +214,7 @@ impl Infer { }) } else { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); Err(err) } diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 93cf9469..0b51645a 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -111,7 +111,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -124,7 +124,7 @@ async fn queue_task( let next_batch = state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); }), } } @@ -226,7 +226,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -336,7 +336,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index e4c3de26..97379bc5 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -148,8 +148,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -170,9 +170,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -219,8 +221,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -234,7 +236,7 @@ async fn prefill( ) -> Option { 
let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -248,11 +250,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -261,7 +267,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -276,7 +282,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -291,13 +297,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + 
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -307,7 +318,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -381,7 +392,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -407,7 +418,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. 
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index ba65b9b6..894d9cab 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -126,7 +126,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -141,7 +141,7 @@ async fn queue_task( .instrument(span) .await; response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); } } } @@ -248,7 +248,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -399,7 +399,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 543ce89f..26cd9584 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -154,8 +154,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -176,9 +176,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -225,8 +227,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -240,7 +242,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -254,11 +256,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - 
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -267,7 +273,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -282,7 +288,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -297,13 +303,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -313,7 +324,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + 
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -387,7 +398,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -413,7 +424,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. diff --git a/router/src/lib.rs b/router/src/lib.rs index e0336190..52926c6c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -386,7 +386,7 @@ pub struct CompletionRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] @@ -733,7 +733,7 @@ impl ChatCompletionChunk { pub(crate) struct ChatRequest { #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// A list of messages comprising the conversation so far. 
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] @@ -850,7 +850,7 @@ pub enum ToolType { Function { function: FunctionName }, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] pub struct FunctionName { pub name: String, } diff --git a/router/src/server.rs b/router/src/server.rs index db8b16ad..d3a280ca 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -11,10 +11,11 @@ use crate::kserve::{ }; use crate::validation::ValidationError; use crate::{ - BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, - Usage, Validation, + BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, + GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, + HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, + SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse, + ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -185,7 +186,7 @@ pub(crate) async fn generate_internal( span: tracing::Span, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // Do not long ultra long inputs, like image payloads. tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]); @@ -301,25 +302,15 @@ pub(crate) async fn generate_internal( ); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_validation_duration", - validation_time.as_secs_f64() - ); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_inference_duration", - inference_time.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_mean_time_per_token_duration", - time_per_token.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_generated_tokens", - response.generated_text.generated_tokens as f64 - ); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration") + .record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens") + .record(response.generated_text.generated_tokens as f64); // Send response let mut output_text = response.generated_text.text; @@ -399,7 +390,7 @@ async fn generate_stream_internal( span: tracing::Span, ) -> (HeaderMap, impl Stream>) { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); tracing::debug!("Input: {}", req.inputs); @@ -427,12 +418,12 @@ async fn generate_stream_internal( let best_of = 
req.parameters.best_of.unwrap_or(1); if best_of != 1 { let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else if req.parameters.decoder_input_details { let err = InferError::from(ValidationError::PrefillDetailsStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else { @@ -500,13 +491,13 @@ async fn generate_stream_internal( span.record("seed", format!("{:?}", generated_text.seed)); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); - metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64); // StreamResponse end_reached = true; @@ -553,7 +544,7 @@ async fn generate_stream_internal( // Skip if we already sent an error if !end_reached && !error { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } @@ -572,8 +563,8 @@ request_body = CompletionRequest, responses( (status = 200, description = "Generated Chat Completion", content( -("application/json" = Completion), -("text/event-stream" = CompletionCompleteChunk), +("application/json" = CompletionFinal), +("text/event-stream" = Chunk), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! 
({"error": "Request failed during generation"})), @@ -604,9 +595,10 @@ async fn completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { + model, max_tokens, seed, stop, @@ -625,7 +617,7 @@ async fn completions( // if suffix is present throw an error if req.suffix.is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -637,7 +629,7 @@ async fn completions( } if req.prompt.0.len() > info.max_client_batch_size { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -675,7 +667,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - ..Default::default() + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), }, }) .collect(); @@ -820,6 +812,10 @@ async fn completions( } }; + let stream = stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1009,8 +1005,9 @@ async fn chat_completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let ChatRequest { + model, logprobs, max_tokens, messages, @@ -1039,7 +1036,7 @@ async fn chat_completions( // response_format and tools are mutually exclusive if response_format.is_some() && tools.as_ref().is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -1053,7 +1050,7 @@ async fn chat_completions( let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { Ok(grammar) => grammar, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1082,7 +1079,7 @@ async fn chat_completions( let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { Ok(inputs) => inputs, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1116,7 +1113,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, - ..Default::default() + adapter_id: model.filter(|m| *m != "tgi").map(String::from), }, }; @@ -1178,6 +1175,11 @@ async fn chat_completions( span, ) .await; + + let response_stream = response_stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1280,7 +1282,7 @@ async fn vertex_compatibility( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + 
metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance if req.instances.is_empty() { @@ -1454,6 +1456,14 @@ pub async fn run( GrammarType, ChatRequest, Message, + MessageContent, + MessageChunk, + Url, + FunctionName, + OutputMessage, + TextMessage, + ToolCallMessage, + ToolCallDelta, ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, diff --git a/router/src/validation.rs b/router/src/validation.rs index 12cf2ab3..07ad14c9 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -157,7 +157,7 @@ impl Validation { )); } - metrics::histogram!("tgi_request_input_length", input_length as f64); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, input_length, max_new_tokens)) } // Return inputs without validation @@ -384,7 +384,7 @@ impl Validation { ignore_eos_token: false, }; - metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); + metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64); Ok(ValidGenerateRequest { inputs, diff --git a/server/Makefile b/server/Makefile index 0099c56a..d701c819 100644 --- a/server/Makefile +++ b/server/Makefile @@ -35,5 +35,5 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes + poetry export -o requirements_cuda.txt --without-hashes --with cuda poetry export -o requirements_rocm.txt --without-hashes diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi index 663984d0..53464719 100644 --- a/server/marlin/marlin_kernels/__init__.pyi +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -59,3 +59,18 @@ def marlin_gemm( Matrix multiplication using Marlin kernels. """ ... + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return torch.ops._C.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 37eccef6..04e1530f 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gptq_marlin_repack", &gptq_marlin_repack, "Repack GPTQ parameters for Marlin"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); + // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. 
+ m.def("fp8_marlin_gemm", &fp8_marlin_gemm); } diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index d1caaab7..102c058e 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &workspace, int64_t size_m, int64_t size_n, int64_t size_k); +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k); + #endif diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu new file mode 100644 index 00000000..aaef67e5 --- /dev/null +++ b/server/marlin/marlin_kernels/fp8_marlin.cu @@ -0,0 +1,1308 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "./gptq_marlin.cuh" +#include "./gptq_marlin_dtypes.cuh" + +using namespace gptq_marlin; + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace fp8_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA 
= 3, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
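+  // For illustration: the acq_rel fence plus the relaxed red.add below publish this
+  // block's writes and bump the counter that the ld.acquire spin loop in
+  // barrier_acquire() is waiting on.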
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + int slice_k_start = tb_k * slice_row; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We scale a `half2` tile in row-major layout for column-wise quantization. + int s_sh_rd = + 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
+ #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
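+  // For illustration: `first` skips the read-accumulate of earlier partial results and
+  // `last` skips writing an intermediate back to C, leaving the final fp32 accumulators
+  // in registers for write_result(); blocks take turns via the lock-based barrier, so
+  // the reduction over a column slice is serialized.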
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + + thread_block_reduce(); + + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + + start_pipes(); + } + } + } +} + + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ + locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, + int group_size) { + int tb_n = th_config.thread_n; + + // Get max scale groups per thread-block + // Fixed for channelwise + int tb_groups = 1; + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; +} + +bool is_valid_cache_size(thread_config_t const& th_config, int 
max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, + prob_k, num_bits, group_size); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) + +template +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { + TORCH_CHECK(num_bits == 8, "num_bits must be 8. 
Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, + group_size, max_shared_mem); + } + + TORCH_CHECK( + exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = -1; + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + int* locks = (int*)workspace; + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 8, "num_bits must be 8. 
Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + // Channelwise only for FP8 + TORCH_CHECK(b_scales.size(0) == 1) + num_groups = b_scales.size(0); + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), size_m, + size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, + dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else { + TORCH_CHECK(false, 
"fp8_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/server/marlin/setup.py b/server/marlin/setup.py index aed84e9e..cc38bccf 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -9,6 +9,7 @@ setup( CUDAExtension( name="marlin_kernels", sources=[ + "marlin_kernels/fp8_marlin.cu", "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", "marlin_kernels/marlin_cuda_kernel.cu", diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 32ee6686..08292920 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) @pytest.fixture(scope="session") @@ -16,7 +19,10 @@ def default_bloom(): revision = "main" filenames = weight_hub_files(model_id, revision, ".safetensors") download_weights(filenames, model_id, revision) - return BLOOMSharded(model_id) + return BLOOMSharded( + model_id, + model_class=BloomForCausalLM, + ) @pytest.fixture(scope="session") diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6e6463bc..c000ef26 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch @pytest.fixture(scope="session") def default_causal_lm(): - return CausalLM("gpt2") + return CausalLM.fallback("gpt2") @pytest.fixture(scope="session") diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index cb2622d9..d5c91bff 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -1,13 +1,12 @@ import pytest from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.models.santacoder import SantaCoder +from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM @pytest.fixture(scope="session") def default_santacoder(): - return SantaCoder("bigcode/santacoder") + return CausalLM.fallback(model_id="bigcode/santacoder") @pytest.fixture diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 943c3b08..02666042 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -20,7 +20,7 @@ def mt0_small_tokenizer(): @pytest.fixture(scope="session") def default_seq2seq_lm(): - return Seq2SeqLM("bigscience/mt0-small") + return Seq2SeqLM.fallback("bigscience/mt0-small") @pytest.fixture diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 9a8da0d6..1e3aaf6b 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -2,6 +2,7 @@ import torch from text_generation_server.layers import ( TensorParallelEmbedding, ) +from text_generation_server.utils.weights import DefaultWeightsLoader class ProcessGroup: @@ -42,7 +43,12 @@ class Weights: def test_weight_hub_files_offline_error(): vocab_size = 17 - weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) + weights = Weights( + rank=0, + world_size=1, 
+ vocab_size=vocab_size, + hidden_dim=256, + ) embeddings = TensorParallelEmbedding("", weights) input_ids = torch.arange(vocab_size) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 8f88b1f8..36b27be8 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -1,13 +1,47 @@ import pytest import torch -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.marlin import MarlinWeight +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + Weights, + WeightsLoader, +) +from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader +from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader +from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path + +@pytest.fixture +def gptq_weights_loader(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="gptq", + quantize="gptq", + sym=True, + ) + + +@pytest.fixture +def gptq_weights_loader_awq(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="awq", + quantize="awq", + sym=True, + ) + + +@pytest.fixture +def marlin_weights_loader(): + return MarlinWeightsLoader(bits=4, is_marlin_24=False) + + dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -58,7 +92,7 @@ dummy_file_system = { dtype=torch.float32, ), }, - "test_get_multi_weights_row": { + "test_get_weights_row": { "weight.weight": torch.tensor( [ [1, 2], @@ -101,7 +135,7 @@ dummy_file_system = { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), }, - "test_get_multi_weights_row_gptq": { + "test_get_weights_row_gptq": { "weight.qweight": torch.tensor( [ [1, 2], @@ -200,7 +234,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_exl2": { + "test_get_weights_row_exl2": { "weight.q_weight": torch.tensor( [ [1, 2], @@ -245,7 +279,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_marlin": { + "test_get_weights_row_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, @@ -308,6 +342,7 @@ class MockWeights(Weights): dummy_fs, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, ): routing = {} self.dummy_fs = dummy_fs @@ -327,6 +362,9 @@ class MockWeights(Weights): self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = ( + DefaultWeightsLoader() if weights_loader is None else weights_loader + ) self._handles = {} def _get_handle(self, filename: Union[Path, str]): @@ -412,12 +450,10 @@ def test_get_weights_col_packed(): ) prefix = "weight" - quantize = None block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size(): ) prefix = "weight" - 
quantize = None block_sizes = 2 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr(): ) prefix = "weight" - quantize = None block_sizes = [1, 1] w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -519,11 +551,9 @@ def test_get_multi_weights_col(): ) prefixes = ["weight", "weight"] - quantize = None w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -545,10 +575,10 @@ def test_get_multi_weights_col(): ) -def test_get_multi_weights_row(): +def test_get_weights_row(): weights = MockWeights( [ - "test_get_multi_weights_row", + "test_get_weights_row", ], device="cpu", dtype=torch.float32, @@ -557,11 +587,9 @@ def test_get_multi_weights_row(): ) prefix = "weight" - quantize = None - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) assert torch.allclose( @@ -576,7 +604,7 @@ def test_get_multi_weights_row(): # test_get_weights_col -def test_get_weights_col_awq(): +def test_get_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -585,14 +613,13 @@ def test_get_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -617,7 +644,7 @@ def test_get_weights_col_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_gtpq(): +def test_get_weights_col_gtpq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -626,14 +653,13 @@ def test_get_weights_col_gtpq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -664,14 +690,13 @@ def test_get_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) scaled_scale_max = 0.3906 * 256 @@ -692,7 +717,7 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(): +def test_get_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_marlin", @@ -701,14 +726,13 @@ def test_get_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( @@ -723,7 +747,7 @@ def test_get_weights_col_marlin(): # test_get_weights_col_packed -def test_get_weights_col_packed_awq(): +def test_get_weights_col_packed_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" block_sizes = 1 w = 
weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_packed_gptq(): +def test_get_weights_col_packed_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(): +def test_get_weights_col_packed_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_marlin", @@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin(): # test_get_multi_weights_col -def test_get_multi_weights_col_awq(): +def test_get_multi_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefixes = ["weight"] - quantize = "awq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" try: w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) except ValueError as e: assert e.args[0] == "get_multi_weights_col is not supported for exl2" -def test_get_multi_weights_col_gptq(): +def test_get_multi_weights_col_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(): +def test_get_multi_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_marlin", @@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, 
dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin(): assert torch.allclose(w.s, expected_weight.s), "s mismatch" -# test_get_multi_weights_row +# test_get_weights_row -def test_get_multi_weights_row_awq(): +def test_get_weights_row_awq(gptq_weights_loader_awq): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_exl2(): +def test_get_weights_row_exl2(): weights = MockWeights( [ - "test_get_multi_weights_row_exl2", + "test_get_weights_row_exl2", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) print(w) @@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_multi_weights_row_gptq(): +def test_get_weights_row_gptq(gptq_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_marlin(): +def test_get_weights_row_marlin(marlin_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_marlin", + "test_get_weights_row_marlin", ], device="cpu", dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd..71ad18f7 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -353,6 +353,7 @@ def quantize( upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index f6cb729e..55cba1cc 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,6 +1,9 @@ import torch +from typing import List, Union from dataclasses import dataclass +from text_generation_server.utils.weights import WeightsLoader, Weights + @dataclass class Exl2Weight: @@ 
-21,3 +24,60 @@ class Exl2Weight: @property def device(self) -> torch.device: return self.q_weight.device + + +class Exl2WeightsLoader(WeightsLoader): + """Loader for exl2-quantized weights.""" + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl") + + def get_weights_col(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + raise ValueError("get_multi_weights_col is not supported for exl2") + + def get_weights_row(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index dd61d081..b76af8f1 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,4 +1,23 @@ +from enum import Enum, auto + import torch +from text_generation_server.utils.import_utils import SYSTEM + + +def get_fp8_linear() -> torch.nn.Module: + """ + Return an FP8 linear `Module` that is compatible with the current system. + """ + + if SYSTEM == "cuda": + major, minor = torch.cuda.get_device_capability() + if major == 8 and minor < 9: + from text_generation_server.layers.marlin import GPTQMarlinFP8Linear + + return GPTQMarlinFP8Linear + + # On other systems let Torch decide if the hardware supports FP8. 
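For context on the FP8 path wired up here: `fp8_quantize` (used by the Marlin FP8 layer later in this patch) amounts to per-tensor dynamic quantization to e4m3. Its body is not part of this hunk, so the snippet below is only a hedged, standalone sketch of that scheme with an illustrative name; it assumes PyTorch >= 2.1 for `torch.float8_e4m3fn`.

import torch

def fp8_quantize_sketch(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # Per-tensor dynamic quantization: scale so the largest magnitude
    # maps onto the e4m3 range, cast, and return the inverse scale for dequant.
    finfo = torch.finfo(qdtype)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qweight, scale.float().reciprocal()

w = torch.randn(128, 256, dtype=torch.float16)
qweight, scale_inv = fp8_quantize_sketch(w)
w_hat = qweight.to(torch.float16) * scale_inv  # rough dequantized reconstruction
print(f"max abs error: {(w.float() - w_hat).abs().max().item():.4f}")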
+ return Fp8Linear def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 56080145..aaa7a68a 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,20 +1,14 @@ from dataclasses import dataclass +from loguru import logger import os -from typing import Optional +from typing import List, Optional, Union +from safetensors import SafetensorError +from text_generation_server.utils.weights import Weights, WeightsLoader import torch from text_generation_server.utils.import_utils import ( SYSTEM, ) - - -@dataclass -class GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool +from text_generation_server.utils.log import log_once @dataclass @@ -69,3 +63,345 @@ elif CAN_EXLLAMA: pass from text_generation_server.layers.gptq.quant_linear import QuantLinear + + +class GPTQWeightsLoader(WeightsLoader): + """ + Loader for GPTQ- and AWQ-quantized weights. + """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + use_exllama = ( + self.bits == 4 + and HAS_EXLLAMA + and self.quantize == "gptq" + and not self.desc_act + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + + if weights.process_group.size() > 1: + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // self.groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + if use_exllama and self.groupsize != -1: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, 
"Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def _get_gptq_params(self, weights: Weights): + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) + self.quant_method = "gptq" diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 8d029817..0271d913 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional +from text_generation_server.utils.weights import DefaultWeightsLoader + DEV = torch.device("cuda:0") @@ -869,6 +871,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -891,6 +894,7 @@ def quantize( dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, + weights_loader=DefaultWeightsLoader(), ) hooks = [] for name, module in model.named_modules(): @@ -943,6 +947,7 @@ def quantize( percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -954,6 +959,7 @@ def quantize( state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index e94e5465..babd86b0 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize): "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" ) elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Linear + from text_generation_server.layers.fp8 import get_fp8_linear - linear = Fp8Linear(weight, bias) + linear = get_fp8_linear()(weight, bias) elif quantize == "bitsandbytes": try: from text_generation_server.layers.bnb import ( diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a1af67a3..9777a47e 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,11 +1,13 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch 
import torch.nn as nn - -from text_generation_server.layers.gptq import GPTQParams +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights, WeightsLoader try: import marlin_kernels @@ -24,16 +26,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_TILE_SIZE = 16 -def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: +class MarlinWeightsLoader(WeightsLoader): + """Loader for Marlin-quantized weights.""" + + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_row(self, weights: Weights, prefix: str): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." 
+ ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: return ( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 and quantize == "gptq" - and gptq_params.quant_method == "gptq" - and gptq_params.bits in GPTQ_MARLIN_BITS - and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES - and gptq_params.sym + and quant_method == "gptq" + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + and sym ) @@ -339,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module): return C +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + qweight, scale = fp8_quantize(weight) + scale = scale.to(torch.float16) + qweight, scales = repack_fp8_for_marlin(qweight, scale) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + + if fp8_tensor.shape[0] % 4 != 0: + raise ValueError( + f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" + ) + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = torch.zeros( + fp8_tensor.shape[0] // 4, + fp8_tensor.shape[1], + dtype=torch.int32, + device=fp8_tensor.device, + ) + + for i in range(4): + packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) + + return packed + + +def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): + """ + Repack FP8 tensor for GPTQ-Marlin. + """ + + out_features, in_features = weight.shape + + # Torch linear layers weights with shape [out_features, in_features], + # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], + # so transpose before packing. 
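The byte layout produced by `pack_fp8_as_int32` above can be sanity-checked in plain PyTorch: four consecutive rows of an e4m3 tensor collapse into one int32 row, one byte per source row. The helpers below are a standalone reference sketch (illustrative names, not part of this patch) and assume PyTorch >= 2.1; the subsequent `gptq_marlin_repack` step still requires the compiled marlin_kernels extension.

import torch

def pack_fp8_rows_as_int32(t: torch.Tensor) -> torch.Tensor:
    # Group rows in fours, reinterpret each fp8 value as a byte,
    # and OR the four bytes into one int32 per column.
    assert t.dtype == torch.float8_e4m3fn and t.shape[0] % 4 == 0
    as_bytes = t.reshape(-1, 4, *t.shape[1:]).view(torch.uint8)
    packed = torch.zeros(t.shape[0] // 4, *t.shape[1:], dtype=torch.int32)
    for i in range(4):
        packed |= as_bytes[:, i].to(torch.int32) << (i * 8)
    return packed

def unpack_int32_rows(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of the packing above, used here only to check the layout.
    parts = [((packed >> (i * 8)) & 0xFF).to(torch.uint8) for i in range(4)]
    return torch.stack(parts, dim=1).reshape(-1, packed.shape[-1])

w8 = torch.randn(8, 16, dtype=torch.float16).to(torch.float8_e4m3fn)
packed = pack_fp8_rows_as_int32(w8)
assert packed.shape == (2, 16)
assert torch.equal(unpack_int32_rows(packed), w8.view(torch.uint8))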
+ qweight = pack_fp8_as_int32(weight.t()) + + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, 8 + ) + + scales = scale.reshape(1, 1).repeat(1, out_features) + scales = permute_scales(scales) + + return repacked, scales + + @dataclass class MarlinWeight: """ diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index b14005e6..8c354b82 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module): max_position_embeddings=rope_scaling[ "original_max_position_embeddings" ], - base=10000.0, + base=base, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, @@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module): beta_fast=32, beta_slow=1, ) - elif rope_scaling["type"] == "su": + elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor( rope_scaling["short_factor"], dtype=torch.float32, device=device ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 038de258..011f105b 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer): weight = weights.get_tensor(f"{prefix}.weight") except: # ...otherwise they are quantized. - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 elif weights.process_group.size() > 1: try: @@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_gate_up(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_gate_up( - prefix, quantize=config.quantize - ) + weight = weights.get_weights_col_packed_gate_up(prefix) if bias: raise NotImplementedError("packed_gate_up only implemented without bias") else: @@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer): """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv( prefix, - quantize=config.quantize, num_heads=num_heads, num_key_value_heads=num_key_value_heads, ) @@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: @@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer): if config.quantize == "exl2": linears = [] for prefix in prefixes: - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None linears.append(get_linear(weight, b, config.quantize)) linear = LayerConcat(linears) else: - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) + weight = weights.get_multi_weights_col(prefixes, dim=dim) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = torch.cat(b, dim=dim) @@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - 
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5ea43290..ba980195 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -11,17 +11,27 @@ from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.bloom import BLOOMSharded -from text_generation_server.models.mpt import MPTSharded +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM +from text_generation_server.models.custom_modeling.mpt_modeling import ( + MPTForCausalLM, +) +from text_generation_server.models.bloom import BloomCausalLMBatch +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models.seq2seq_lm import Seq2SeqLM -from text_generation_server.models.rw import RW -from text_generation_server.models.opt import OPTSharded -from text_generation_server.models.galactica import GalacticaSharded -from text_generation_server.models.santacoder import SantaCoder -from text_generation_server.models.t5 import T5Sharded -from text_generation_server.models.gpt_neox import GPTNeoxSharded -from text_generation_server.models.phi import Phi +from text_generation_server.models.galactica import GalacticaCausalLMBatch +from text_generation_server.models.custom_modeling.neox_modeling import ( + GPTNeoxForCausalLM, +) +from text_generation_server.models.custom_modeling.phi_modeling import ( + PhiConfig, + PhiForCausalLM, +) +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) from text_generation_server.utils.import_utils import SYSTEM @@ -41,9 +51,6 @@ __all__ = [ "CausalLM", "GalacticaSharded", "Seq2SeqLM", - "SantaCoder", - "OPTSharded", - "T5Sharded", "get_model", ] @@ -53,38 +60,65 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_gpt2 import FlashGPT2 - from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, ) - from text_generation_server.models.flash_qwen2 import ( - FlashQwen2, + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, ) - from text_generation_server.models.flash_cohere import ( - FlashCohere, + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, ) - from text_generation_server.models.flash_gemma import ( - FlashGemma, + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, ) - from text_generation_server.models.flash_gemma2 import ( - FlashGemma2, + from 
text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, + ) + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, ) from text_generation_server.models.pali_gemma import ( - PaliGemma, + PaliGemmaBatch, ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, ) from text_generation_server.models.idefics import IDEFICSSharded - from text_generation_server.models.llava_next import LlavaNext - from text_generation_server.models.idefics2 import Idefics2 - from text_generation_server.models.flash_mistral import FlashMistral - from text_generation_server.models.flash_mixtral import FlashMixtral - from text_generation_server.models.flash_phi import FlashPhi - from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 - from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") @@ -93,21 +127,7 @@ except ImportError as e: if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(FlashGPT2) - __all__.append(FlashNeoXSharded) - __all__.append(FlashRWSharded) - __all__.append(FlashSantacoderSharded) - __all__.append(FlashLlama) __all__.append(IDEFICSSharded) - __all__.append(FlashMistral) - __all__.append(FlashMixtral) - __all__.append(FlashDbrx) - __all__.append(FlashPhi) - __all__.append(FlashQwen2) - __all__.append(FlashStarcoder2) - __all__.append(FlashGemma) - __all__.append(FlashGemma2) - __all__.append(FlashCohere) MAMBA_AVAILABLE = True try: @@ -148,6 +168,11 @@ class ModelType(enum.Enum): "name": "Gemma", "url": "https://huggingface.co/google/gemma-7b", } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } GEMMA2 = { "type": "gemma2", "name": "Gemma2", @@ -445,13 +470,16 @@ def get_model( ) if model_id.startswith("facebook/galactica"): - return GalacticaSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + # Yes galactica is just an OPT model. 
+ model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=GalacticaCausalLMBatch, ) if ( @@ -460,22 +488,26 @@ def get_model( and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, ) elif sharded: raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") ) else: - return SantaCoder( - model_id, - revision, + return CausalLM.fallback( + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -483,38 +515,44 @@ def get_model( ) if model_type == BLOOM: - return BLOOMSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=BloomForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=BloomCausalLMBatch, ) elif model_type == MPT: - return MPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=MPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == GPT2: if FLASH_ATTENTION: try: - return FlashGPT2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) except RuntimeError as e: # Lots of legacy models with various weight names. 
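The pattern applied throughout get_model in this patch is worth spelling out: instead of one wrapper subclass per architecture, a single generic wrapper is parameterized with model_class (and optionally batch_class), with a fallback constructor for plain transformers loading. A toy, self-contained sketch of that shape, with entirely hypothetical names, might look like:

from dataclasses import dataclass
from typing import Dict, Tuple, Type

class BloomLikeArch: ...
class GptLikeArch: ...

@dataclass
class ToyCausalLM:
    model_id: str
    model_class: Type          # architecture-specific module, injected
    batch_class: Type = dict   # batching strategy, also injected

    @classmethod
    def fallback(cls, model_id: str) -> "ToyCausalLM":
        # Generic path: let an auto-class pick the architecture.
        return cls(model_id=model_id, model_class=object)

# model_type -> (wrapper, architecture class), mirroring the shape of get_model.
REGISTRY: Dict[str, Tuple[Type, Type]] = {
    "bloom_like": (ToyCausalLM, BloomLikeArch),
    "gpt_like": (ToyCausalLM, GptLikeArch),
}

def toy_get_model(model_type: str, model_id: str) -> ToyCausalLM:
    wrapper, arch = REGISTRY.get(model_type, (None, None))
    if wrapper is None:
        return ToyCausalLM.fallback(model_id)
    return wrapper(model_id=model_id, model_class=arch)

assert toy_get_model("bloom_like", "x").model_class is BloomLikeArch
assert toy_get_model("unknown", "x").model_class is object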
logger.warning(f"Couldn't load flash gpt2 variant: {e}") - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -525,7 +563,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -535,25 +573,28 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: - return FlashNeoXSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: - return GPTNeoxSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=GPTNeoxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -564,16 +605,18 @@ def get_model( elif model_type == PHI: if FLASH_ATTENTION: - return FlashPhi( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -588,9 +631,11 @@ def get_model( "Legacy phi-msft is not supported with Flash Attention" ) else: - return Phi( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=PhiForCausalLM, + config_class=PhiConfig, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -599,9 +644,10 @@ def get_model( elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: if FLASH_ATTENTION: - return FlashLlama( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -611,7 +657,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -621,18 +667,22 @@ def get_model( ) if model_type == GEMMA: if FLASH_ATTENTION: - return FlashGemma( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -642,18 +692,22 @@ def get_model( ) elif model_type == GEMMA2: if FLASH_ATTENTION: - return FlashGemma2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) else: - return CausalLM( + return 
CausalLM.fallback( model_id, revision, quantize=quantize, @@ -664,18 +718,20 @@ def get_model( if model_type == COHERE: if FLASH_ATTENTION: - return FlashCohere( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -686,18 +742,23 @@ def get_model( if model_type == DBRX: if FLASH_ATTENTION: - return FlashDbrx( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Dbrx works better in bfloat16. + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -711,27 +772,41 @@ def get_model( if FLASH_ATTENTION: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this model") - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) else: - return RW( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -742,18 +817,20 @@ def get_model( if model_type == MISTRAL: if FLASH_ATTENTION: - return FlashMistral( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -764,18 +841,20 @@ def get_model( if model_type == MIXTRAL: if FLASH_ATTENTION: - return FlashMixtral( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) else: - return CausalLM( + return 
CausalLM.fallback( model_id, revision, quantize=quantize, @@ -786,19 +865,22 @@ def get_model( if model_type == STARCODER2: if FLASH_ATTENTION: - return FlashStarcoder2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -809,17 +891,20 @@ def get_model( if model_type == QWEN2: if FLASH_ATTENTION: - return FlashQwen2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -829,9 +914,10 @@ def get_model( ) if model_type == OPT: - return OPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -839,13 +925,20 @@ def get_model( ) if model_type == T5: - return T5Sharded( - model_id, - revision, + return Seq2SeqLM( + model_id=model_id, + model_class=T5ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) if model_type == IDEFICS: if FLASH_ATTENTION: @@ -861,34 +954,45 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == IDEFICS2: if FLASH_ATTENTION: - return Idefics2( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. 
+ processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == "paligemma": + if model_type == PALIGEMMA: if FLASH_ATTENTION: - return PaliGemma( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: - return LlavaNext( - model_id, - revision, + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -912,7 +1016,7 @@ def get_model( elif quantize == "exl2": raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -921,7 +1025,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, @@ -933,7 +1037,7 @@ def get_model( auto_map = config_dict.get("auto_map", None) if trust_remote_code and auto_map is not None: if "AutoModelForCausalLM" in auto_map.keys(): - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -942,7 +1046,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if "AutoModelForSeq2SeqLM" in auto_map.keys(): - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 17aa12e8..732b4c53 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -4,22 +4,12 @@ import torch.distributed from typing import Optional, Type from transformers import ( - AutoTokenizer, - AutoConfig, PreTrainedTokenizerBase, ) -from text_generation_server.models.custom_modeling.bloom_modeling import ( - BloomForCausalLM, -) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) class BloomCausalLMBatch(CausalLMBatch): @@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch): class BLOOMSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - 
config = AutoConfig.from_pretrained( - model_id, - revision=revision, - slow_but_exact=False, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.pad_token_id = 3 - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - prefix="transformer", - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = BloomForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 10c64c66..0ea82b1e 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,13 +1,26 @@ import torch import time +import torch.distributed from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase +from transformers import ( + AutoConfig, + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -478,10 +491,93 @@ class CausalLMBatch(Batch): return len(self.requests) +@dataclass +class CausalLMBatchKeysLast(Batch): + keys_head_dim_last: bool = False + + class CausalLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + config_class=AutoConfig, + batch_class=CausalLMBatch, + ): + self.batch_class = batch_class + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. 
+ dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = config.pad_token_id + + torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) + + prefix = "" + model = model_class(prefix, config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -537,7 +633,12 @@ class CausalLM(Model): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - super(CausalLM, self).__init__( + self = cls.__new__( + cls, + ) + self.batch_class = CausalLMBatch + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -545,15 +646,11 @@ class CausalLM(Model): dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + return self.batch_class def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 0d8a1b59..77b89c5b 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel): class BloomForCausalLM(BloomPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.transformer = BloomModel(config, weights) diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 56618bf1..27b9ff1c 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module): class CLIPTextTransformer(nn.Module): - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix: str, config: CLIPTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size @@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel): _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix, config: CLIPTextConfig): 
super().__init__(config) - self.text_model = CLIPTextTransformer(config) + self.text_model = CLIPTextTransformer(prefix, config) # Initialize weights and apply final processing self.post_init() diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e088f9aa..25719b99 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -363,9 +362,9 @@ class CohereMLP(nn.Module): class FlashCohereLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashCohereAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -416,18 +415,19 @@ class FlashCohereLayer(nn.Module): class FlashCohereModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashCohereLayer( + prefix, layer_id, config, weights, @@ -436,7 +436,7 @@ class FlashCohereModel(torch.nn.Module): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="model.norm", weights=weights, eps=config.layer_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps ) self.gradient_checkpointing = False @@ -486,10 +486,15 @@ class FlashCohereModel(torch.nn.Module): class FlashCohereForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashCohereModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashCohereModel(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -499,7 +504,7 @@ class FlashCohereForCausalLM(torch.nn.Module): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", weights=weights, ) self.logit_scale = config.logit_scale diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index aea7f399..44411687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig): + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + def __init__( self, d_model: int = 2048, @@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig): **kwargs, ) + @property + def num_key_value_heads(self): + # We can't use the attribute map, since this the number of KV + # heads is not top-level. 
+ return self.attn_config.kv_n_heads + def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x @@ -593,9 +605,9 @@ class DenseMoE(nn.Module): class DbrxLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.blocks.{layer_id}" + prefix = f"{prefix}.blocks.{layer_id}" self.attn = DbrxNormAttentionNorm( prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights @@ -637,16 +649,17 @@ class DbrxLayer(nn.Module): class DbrxModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights + prefix=f"{prefix}.wte", weights=weights ) self.layers = nn.ModuleList( [ DbrxLayer( + prefix, layer_id, config, weights, @@ -655,7 +668,7 @@ class DbrxModel(torch.nn.Module): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=1e-5 + prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5 ) self.head_size = self.layers[0].attn.self_attn.head_size @@ -702,10 +715,15 @@ class DbrxModel(torch.nn.Module): class FlashDbrxForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = DbrxModel(config, weights) + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = DbrxModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index cfa6b2fe..a3ce5521 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig): class Gemma2FastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -305,7 +304,7 @@ class Gemma2MLP(nn.Module): class FlashGemma2Layer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): + def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool): super().__init__() self.self_attn = FlashGemma2Attention( prefix=f"{prefix}.self_attn", @@ -376,7 +375,7 @@ class FlashGemma2Layer(nn.Module): class FlashGemma2Model(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -442,7 +441,7 @@ class FlashGemma2Model(torch.nn.Module): class 
FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 842df0d4..34a7efa2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig): class GemmaFastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -261,7 +260,7 @@ class FlashGemmaAttention(torch.nn.Module): class GemmaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -299,7 +298,7 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal @@ -354,7 +353,7 @@ class FlashGemmaLayer(nn.Module): class FlashGemmaModel(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -419,7 +418,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 9f800146..cbfcb1b8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights): # Weights weight = weights.get_weights_col_packed_qkv( f"{prefix}.c_attn", - config.quantize, config.num_attention_heads, config.num_attention_heads, ) @@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool): """load_row, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) else: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T @@ -155,9 +154,7 @@ def 
load_row(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool): """load_col, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=1 - ) + weight = weights.get_multi_weights_col([prefix], dim=1) else: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T @@ -261,7 +258,7 @@ class FlashGPT2Attention(torch.nn.Module): class GPT2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.activation_function self.act = ( @@ -298,7 +295,7 @@ class GPT2MLP(nn.Module): class FlashGPT2Layer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.self_attn = FlashGPT2Attention( prefix=f"{prefix}.attn", config=config, weights=weights @@ -350,7 +347,7 @@ class FlashGPT2Layer(nn.Module): class FlashGPT2Model(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -414,7 +411,7 @@ class FlashGPT2Model(torch.nn.Module): class FlashGPT2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 77a7e2d5..78832341 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -54,7 +54,7 @@ if SYSTEM == "rocm": raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -def load_attention(config, prefix, weights, layer_id): +def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. 
bias = getattr(config, "attention_bias", False) head_size = config.hidden_size // config.num_attention_heads @@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 396969cd..8028dbe8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module): class MistralMLP(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -328,7 +328,7 @@ class MistralMLP(nn.Module): class MistralLayer(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", @@ -392,7 +392,7 @@ class MistralLayer(nn.Module): class MistralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, name=None): + def __init__(self, prefix: str, config, weights, name=None): if name is None: name = "model" super().__init__() diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2d6a7f97..49c0e903 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -155,7 +154,7 @@ def _load_gqa(config, prefix: str, weights): ) -def _load_experts(config, prefix, mat, weights): +def _load_experts(config, prefix: str, mat, weights): if config.quantize is not None: raise NotImplementedError("Mixtral does not support weight quantization yet.") @@ -475,7 +474,7 @@ class DenseMoE(nn.Module): class MixtralLayer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() prefix = f"{prefix}.layers.{layer_id}" @@ -536,7 +535,7 @@ class MixtralLayer(nn.Module): class MixtralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( @@ -610,7 +609,7 
@@ class MixtralModel(torch.nn.Module): class FlashMixtralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.model = MixtralModel(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 33aebc2b..85dcb2a6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): - weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) + weight = weights.get_multi_weights_col([prefix], dim=0) if isinstance(weight, torch.Tensor): # Only on non quantized versions weight = ( @@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( @@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ] ) self.final_layer_norm = FastLayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) - self.gpt_neox = FlashGPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index f237ea37..6c508264 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -258,9 +257,9 @@ class PhiMLP(nn.Module): class FlashPhiLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashPhiAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -307,18 +306,19 @@ class FlashPhiLayer(nn.Module): class 
FlashPhiModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashPhiLayer( + prefix, layer_id, config, weights, @@ -378,10 +378,15 @@ class FlashPhiModel(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashPhiModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashPhiModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 1cc6a613..a98709c5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module): class Qwen2Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = Qwen2Attention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module): class Qwen2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ Qwen2Layer( + prefix, layer_id, config, weights, @@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module): ] ) self.norm = FastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) self.gradient_checkpointing = False @@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module): class Qwen2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = Qwen2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Qwen2Model(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index e7614232..65b40fed 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -23,7 +23,7 @@ from text_generation_server.layers.attention import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and 
weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig): attribute_map = { "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", + "num_key_value_heads": "n_head_kv", } def __init__( @@ -127,7 +128,7 @@ class FlashRWAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -236,7 +237,7 @@ class FlashRWLargeAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -358,7 +359,7 @@ class FlashRWLargeAttention(torch.nn.Module): class FlashMLP(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.act = torch.nn.functional.gelu @@ -380,6 +381,7 @@ class FlashRWLayer(nn.Module): def __init__( self, layer_id, + prefix: str, config, weights, ): @@ -388,7 +390,7 @@ class FlashRWLayer(nn.Module): parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", @@ -479,7 +481,7 @@ class FlashRWLayer(nn.Module): class FlashRWLayerNorm(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.num_ln = config.num_ln_in_parallel_attn @@ -518,9 +520,9 @@ class FlashRWLayerNorm(nn.Module): class FlashRWLargeLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_layer = FlashRWLayerNorm(config, prefix, weights) @@ -580,18 +582,18 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.word_embeddings = TensorParallelEmbedding( - prefix="transformer.word_embeddings", weights=weights + prefix=f"{prefix}.word_embeddings", weights=weights ) if config.new_decoder_architecture: self.h = nn.ModuleList( [ - FlashRWLargeLayer(layer_id, config, weights) + FlashRWLargeLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -599,14 +601,14 @@ class FlashRWModel(FlashRWPreTrainedModel): else: self.h = nn.ModuleList( [ - FlashRWLayer(layer_id, config, weights) + FlashRWLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = self.h[0].self_attention.num_heads_kv self.ln_f = FastLayerNorm.load( - prefix="transformer.ln_f", + prefix=f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon, ) @@ -653,10 +655,15 @@ class FlashRWModel(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.transformer = FlashRWModel(config, weights) + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.transformer = FlashRWModel(prefix, config, weights) self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py 
index 30989a37..77b9d49c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -81,11 +82,13 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - gptq_params = weights._get_gptq_params() - if gptq_params.quant_method == "gptq": + loader = weights.weights_loader + assert isinstance(loader, GPTQWeightsLoader) + loader._get_gptq_params(weights) + if loader.quant_method == "gptq": g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - elif gptq_params.quant_method == "awq": + elif loader.quant_method == "awq": g_idx = None from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, @@ -100,8 +103,8 @@ def _load_multi_mqa_gptq( qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, + bits=loader.bits, + groupsize=loader.groupsize, use_exllama=HAS_EXLLAMA, ) @@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T else: - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=0 - ) + weight = weights.get_multi_weights_col([prefix], dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) @@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T else: - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -346,16 +347,16 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_1 = FastLayerNorm.load( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon ) self.ln_2 = FastLayerNorm.load( prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon ) - self.attn = FlashMQAttention( + self.self_attn = FlashMQAttention( prefix=f"{prefix}.attn", config=config, weights=weights, @@ -378,7 +379,7 @@ class Block(nn.Module): max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) - hidden_states = self.attn( + hidden_states = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, @@ -396,25 +397,26 @@ class Block(nn.Module): class FlashSantacoderModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config self.process_group = weights.process_group self.wte = TensorParallelEmbedding( - prefix="transformer.wte", + prefix=f"{prefix}.wte", weights=weights, reduce=False, ) self.wpe = TensorParallelEmbedding( - prefix="transformer.wpe", + prefix=f"{prefix}.wpe", weights=weights, reduce=False, ) - self.h = nn.ModuleList( + self.layers = nn.ModuleList( [ Block( + prefix, layer_id, config, weights, @@ -426,8 +428,8 @@ class FlashSantacoderModel(nn.Module): 
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon ) - self.head_size = self.h[0].attn.head_size - self.num_heads = self.h[0].attn.num_heads + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads def forward( self, @@ -446,7 +448,7 @@ class FlashSantacoderModel(nn.Module): torch.distributed.all_reduce(hidden_states, group=self.process_group) residual = None - for i, layer in enumerate(self.h): + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, residual, @@ -464,11 +466,18 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - self.transformer = FlashSantacoderModel(config, weights) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + config.transpose = config.architectures[0].startswith("GPT2") + self.model = FlashSantacoderModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) def forward( @@ -485,7 +494,7 @@ class FlashSantacoderForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer( + hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index a0273c37..19556f78 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) @@ -417,14 +416,14 @@ class Starcoder2Layer(nn.Module): class Starcoder2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ @@ -437,7 +436,7 @@ class Starcoder2Model(torch.nn.Module): ] ) self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( - prefix="model.norm", weights=weights, eps=config.norm_epsilon + prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon ) self.gradient_checkpointing = False @@ -489,10 +488,15 @@ class Starcoder2Model(torch.nn.Module): class FlashStarcoder2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - self.model = Starcoder2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Starcoder2Model(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -502,7 +506,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", 
weights=weights, ) diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 6d38442c..567131ef 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module): self.config = config config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator - self.language_model = load_text_model( + self.text_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", config=config.text_config, weights=weights, @@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module): image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, ): - inputs_embeds = self.language_model.embed_tokens(input_ids) + inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" @@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module): input_ids, inputs_embeds, image_features ) - hidden_states = self.language_model.model( + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.language_model.lm_head(hidden_states) + logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index f7981bf5..fb09a8f1 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel): class MPTModel(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): # config._validate_config() super().__init__(config) self.world_size = weights.process_group.size() @@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel): f"Requested norm type ({config.norm_type}) is not implemented within this repo." 
) - self.wte = TensorParallelEmbedding("transformer.wte", weights) + self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) if not self.alibi: - self.wpe = TensorParallelEmbedding("transformer.wpe", weights) + self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) self.blocks = nn.ModuleList( [ - MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) + MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) for i in range(config.n_layers) ] ) @@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel): class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(config, weights) + self.transformer = MPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) self.logit_scale = None if config.logit_scale is not None: diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index fcad32fa..8998778f 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module): class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", + prefix=f"{prefix}.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", + prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.attention = GPTNeoXAttention( - config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights ) self.mlp = GPTNeoXMLP( - config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights ) def forward( @@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module): class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.num_attention_heads = config.num_attention_heads self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( [ - GPTNeoXLayer(layer_id, config, weights) + GPTNeoXLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.final_layer_norm = nn.LayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - 
def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.gpt_neox = GPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = GPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 9b2d01e0..84a1c069 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module): This module learns positional embeddings up to a fixed maximum size. """ - def __init__(self, weights): + def __init__(self, prefix: str, weights): super().__init__() self.offset = 2 self.weight = nn.Parameter( - weights.get_tensor("model.decoder.embed_positions.weight") + weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") ) def forward( @@ -311,11 +311,11 @@ class OPTAttention(nn.Module): class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, config: OPTConfig, weights): + def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"model.decoder.layers.{layer_id}" + prefix = f"{prefix}.decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel): class OPTDecoder(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.layerdrop @@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = TensorParallelEmbedding( - prefix="model.decoder.embed_tokens", weights=weights + prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) - self.embed_positions = OPTLearnedPositionalEmbedding(weights) + self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) if config.word_embed_proj_dim != config.hidden_size: self.project_out = FastLinear.load( - config, prefix="model.decoder.project_out", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_out", + weights=weights, + bias=False, ) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: self.project_in = FastLinear.load( - config, prefix="model.decoder.project_in", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_in", + weights=weights, + bias=False, ) else: self.project_in = None @@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel): # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm.load( - prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS + prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS ) else: self.final_layer_norm = None self.layers = nn.ModuleList( [ - OPTDecoderLayer(layer_id, config, weights) + OPTDecoderLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel): class 
OPTModel(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) - self.decoder = OPTDecoder(config, weights) + self.decoder = OPTDecoder(prefix, config, weights) # Initialize weights and apply final processing def forward( @@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) - self.model = OPTModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = OPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="model.decoder.embed_tokens", weights=weights + config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index 04b470eb..b4d56db1 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -248,16 +248,16 @@ class PhiBlock(nn.Module): # PhiModel implements the embedding layer and the transformer blocks. class PhiModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.tp_rank = weights.process_group.rank() self.tp_world_size = weights.process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.embd.wte", weights=weights + prefix=f"{prefix}.embd.wte", weights=weights ) self.blocks = nn.ModuleList( [ - PhiBlock(f"transformer.h.{layer_id}", config, weights) + PhiBlock(f"{prefix}.h.{layer_id}", config, weights) for layer_id in range(config.n_layer) ] ) @@ -289,9 +289,15 @@ class PhiModel(nn.Module): # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. 
class PhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = PhiModel(config, weights) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = PhiModel(prefix, config, weights) self.lm_head = PhiCausalLMHead(config, weights) def forward( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4f276ed4..2ca9eef3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -10,7 +10,12 @@ import numpy as np from loguru import logger from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizerBase, + AutoConfig, + AutoTokenizer, + GenerationConfig, +) from typing import Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata @@ -21,6 +26,12 @@ from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, + hub, +) from text_generation_server.models.types import ( Batch, Tokens, @@ -39,6 +50,7 @@ from text_generation_server.models.globals import ( from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( @@ -799,29 +811,123 @@ class FlashCausalLMBatch(Batch): return len(self.requests) +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + class FlashCausalLM(Model): def __init__( self, model_id: str, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizerBase, - num_layers: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - rank: int = 0, - world_size: int = 1, - sliding_window: Optional[int] = None, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], + tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, + default_dtype=torch.float16, + aliases=None, + # Used for Santacoder override of config + num_kv_heads=None, + skip_special_tokens: bool = True, ): - self.num_layers = num_layers - self.num_kv_heads = num_kv_heads - self.head_size = head_size + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + 
device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError(f"{model_class} is only available on GPU") + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + if isinstance(generation_config.eos_token_id, (list, set)): + # TODO Huge hack + tokenizer._eos_token_ids = set(generation_config.eos_token_id) + except Exception: + pass + + config = config_class.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + + torch.distributed.barrier(group=self.process_group) + + weights_loader = get_loader(quantize, model_id, revision) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, + ) + + prefix = "" + model = model_class(prefix, config, weights) + torch.distributed.barrier(group=self.process_group) + + # VLM models define the config we care about in their text_config + text_config = getattr(config, "text_config", None) + if text_config is not None: + config = text_config + + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + + self.num_layers = config.num_hidden_layers + # Validation is done in the model itself + if num_kv_heads is None: + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround + if num_kv_heads is None: + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 + self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} self.kv_cache = [] - super(FlashCausalLM, self).__init__( + super().__init__( model_id=model_id, model=model, tokenizer=tokenizer, @@ -830,7 +936,7 @@ class FlashCausalLM(Model): device=device, rank=rank, world_size=world_size, - sliding_window=sliding_window, + sliding_window=config.sliding_window, ) @property @@ -1578,3 +1684,72 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. 
+ if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + # TODO: this is a hack to avoid the gate_proj for + # FlashStarcoder2 that doesnt have these layers + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py deleted file mode 100644 index 9f8bcb3f..00000000 --- a/server/text_generation_server/models/flash_cohere.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer, AutoConfig - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( - FlashCohereForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashCohere(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashCohere is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - 
weights._set_gptq_params(model_id, revision) - - model = FlashCohereForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashCohere, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py deleted file mode 100644 index 2aba6a00..00000000 --- a/server/text_generation_server/models/flash_dbrx.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( - FlashDbrxForCausalLM, - DbrxConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashDbrx(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashDBRX is only available on GPU") - - try: - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - # FIXME: change back to model id once the tokenizer.json is merged - tokenizer = GPT2TokenizerFast.from_pretrained( - "Xenova/dbrx-instruct-tokenizer", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = DbrxConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashDbrxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashDbrx, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py deleted file mode 100644 index 
7e2b8780..00000000 --- a/server/text_generation_server/models/flash_gemma.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - FlashGemmaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py deleted file mode 100644 index 86cfc7e2..00000000 --- a/server/text_generation_server/models/flash_gemma2.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import PretrainedConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( - FlashGemma2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: 
Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = PretrainedConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py deleted file mode 100644 index 323fcafa..00000000 --- a/server/text_generation_server/models/flash_gpt2.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from transformers.models.gpt2 import GPT2Tokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( - FlashGPT2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGPT2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGPT2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) 
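The per-model wrappers removed in this and the following hunks (FlashCohere, FlashDbrx, FlashGemma, FlashGemma2, FlashGPT2, FlashLlama, ...) are subsumed by the generalized FlashCausalLM constructor added in flash_causal_lm.py above, which now receives the modeling class directly. The routing code that calls it is not shown in this excerpt, so the call below is only a sketch of the expected pattern, with an assumed model id:

import torch
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
    FlashGPT2ForCausalLM,
)

# Roughly what the deleted FlashGPT2 wrapper used to set up by hand.
model = FlashCausalLM(
    model_id="gpt2",                    # assumed example checkpoint
    model_class=FlashGPT2ForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,                         # falls back to default_dtype on CUDA
    default_dtype=torch.float16,
    trust_remote_code=False,
)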
- - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashGPT2ForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashGPT2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py deleted file mode 100644 index d996b9c3..00000000 --- a/server/text_generation_server/models/flash_llama.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - FlashLlamaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, - hub, -) - -tracer = trace.get_tracer(__name__) - -from text_generation_server.utils.import_utils import SYSTEM - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashLlama(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - lora_adapter_ids: Optional[list] = [], - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashLlama is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - try: - generation_config = GenerationConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - if isinstance(generation_config.eos_token_id, (list, set)): - # TODO Huge hack - tokenizer._eos_token_ids = set(generation_config.eos_token_id) - except Exception: - pass - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights 
= Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashLlamaForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashLlama, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0f5746de..2b2bd2e0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,24 +1,7 @@ import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import set_sliding_window -from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( - FlashMistralForCausalLM, - MistralConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) ADAPTER_LAYERS = [ @@ -33,88 +16,7 @@ ADAPTER_LAYERS = [ ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} -class BaseFlashMistral(FlashCausalLM): - def __init__( - self, - model_cls, - model_id: str, - config_cls=AutoConfig, - revision: 
Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashMistral is only available on GPU") - - tokenizer = tokenizer_class.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = config_cls.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = model_cls(prefix, config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - num_layers, num_kv_heads, head_size = self.get_layer_config(model) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.model.layers), - model.model.num_key_value_heads, - model.model.head_size, - ) - +class FlashMistral(FlashCausalLM): @property def supports_adapter_loading(self) -> bool: return True @@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM): # This accounts for VLMs (e.g. LlavaNext, Idefics2) # that have a language_model inside of the larger model. 
- if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): + if hasattr(self.model, "text_model"): _model = self.model.text_model else: _model = self.model @@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM): def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL - - -class FlashMistral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMistral, self).__init__( - config_cls=MistralConfig, - model_cls=FlashMistralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py deleted file mode 100644 index 587d423f..00000000 --- a/server/text_generation_server/models/flash_mixtral.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from typing import Optional - -from text_generation_server.models.flash_mistral import BaseFlashMistral -from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( - MixtralConfig, - FlashMixtralForCausalLM, -) - - -class FlashMixtral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMixtral, self).__init__( - config_cls=MixtralConfig, - model_cls=FlashMixtralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py deleted file mode 100644 index ac1fd573..00000000 --- a/server/text_generation_server/models/flash_neox.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashNeoXSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashNeoX is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - 
model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashGPTNeoXForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashNeoXSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.gpt_neox.layers), - num_kv_heads=model.gpt_neox.num_heads, - head_size=model.gpt_neox.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py deleted file mode 100644 index a530d1c3..00000000 --- a/server/text_generation_server/models/flash_phi.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_phi_modeling import ( - FlashPhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashPhi(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashPhi is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashPhiForCausalLM(config, weights) - if speculator: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - 
Path(speculator).exists() and Path(speculator).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - speculator, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - speculator, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(speculator) / "config.json") - medusa_head = str(Path(speculator) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - - torch.distributed.barrier(group=self.process_group) - super(FlashPhi, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py deleted file mode 100644 index cd6078f1..00000000 --- a/server/text_generation_server/models/flash_qwen2.py +++ /dev/null @@ -1,93 +0,0 @@ -import math - -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, -) -from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( - Qwen2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashQwen2(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashQwen2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = 
Qwen2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py deleted file mode 100644 index b1f75adc..00000000 --- a/server/text_generation_server/models/flash_rw.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_rw_modeling import ( - RWConfig, - FlashRWForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashRWSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashRW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - ) - - config.quantize = quantize - config.speculator = speculator - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashRWForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashRWSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=model.transformer.cache_size, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py deleted file mode 100644 index e1a7b36e..00000000 --- a/server/text_generation_server/models/flash_santacoder.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -import 
torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List -import json -import os - -from huggingface_hub import hf_hub_download -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashSantacoderSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=True, - ) - config.quantize = quantize - config.speculator = speculator - config.transpose = config.architectures[0].startswith("GPT2") - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={"transformer.wte.weight": ["lm_head.weight"]}, - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashSantacoderForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashSantacoderSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=1, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py deleted file mode 100644 index 369e9e4c..00000000 --- a/server/text_generation_server/models/flash_starcoder2.py +++ /dev/null @@ -1,84 +0,0 @@ -import math - -import torch - -from typing import Optional - -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, -) -from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( - Starcoder2Config, - FlashStarcoder2ForCausalLM, -) -from text_generation_server.utils 
import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -# Starcoder2 has the same base as Mistral -class FlashStarcoder2(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashStarcoder2 is only available on GPU") - - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = Starcoder2Config.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashStarcoder2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 30c92d90..2d43244a 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch): padding_right_offset=padding_right_offset, max_tokens=max_tokens, ) - - -class GalacticaSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - tokenizer.pad_token_id = config.pad_token_id - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", 
"marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return GalacticaCausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py deleted file mode 100644 index c37cfb7d..00000000 --- a/server/text_generation_server/models/gpt_neox.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.neox_modeling import ( - GPTNeoxForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class GPTNeoxSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = GPTNeoxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) - 
- return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index f2955bd0..0deab6ce 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -23,6 +23,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.quantization import get_loader class IDEFICSSharded(IdeficsCausalLM): @@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM): device=device, dtype=dtype, process_group=self.process_group, + weights_loader=weights_loader, ) model = IdeficsForVisionText2Text(config, weights) diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py deleted file mode 100644 index 314c0500..00000000 --- a/server/text_generation_server/models/idefics2.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.idefics2 import ( - Idefics2ForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class Idefics2(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - # XXX: Extremely important to cap resolution in order to limit - # VRAM usage. 
- size={"longest_edge": 448, "shortest_edge": 378}, - ) - super().__init__( - model_cls=Idefics2ForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py deleted file mode 100644 index effe8b91..00000000 --- a/server/text_generation_server/models/llava_next.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class LlavaNext(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - super().__init__( - model_cls=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.language_model.model.layers), - model.language_model.model.num_key_value_heads, - model.language_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9189b45c..4ed9722c 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -28,6 +28,7 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -448,8 +449,17 @@ class Mamba(Model): config.quantize = quantize config.speculator = speculator torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) model = MambaModel(config, weights) torch.distributed.barrier(group=self.process_group) super(Mamba, self).__init__( diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index c90fd38a..09130b85 100644 --- a/server/text_generation_server/models/model.py +++ 
b/server/text_generation_server/models/model.py @@ -60,7 +60,7 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = self.adapter_target_to_layer() + self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -187,6 +187,8 @@ class Model(ABC): into model. Otherwise, the adapter weights are applied during the forward pass and stored separately from the base model parameters. """ + if self.target_to_layer is None: + self.target_to_layer = self.adapter_target_to_layer() if adapter_index in self.loaded_adapters: # Adapter already loaded return diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py deleted file mode 100644 index 1e79b25f..00000000 --- a/server/text_generation_server/models/mpt.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.distributed - -from pathlib import Path -from typing import Optional, Type -from opentelemetry import trace -from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase -from huggingface_hub import hf_hub_download -import json - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.custom_modeling.mpt_modeling import ( - MPTForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class MPTCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) - batch.keys_head_dim_last = False - return batch - - -class MPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - # If model_id is a local path, load the file directly - local_path = Path(model_id, "config.json") - if local_path.exists(): - filename = str(local_path.resolve()) - else: - filename = hf_hub_download( - model_id, revision=revision, filename="config.json" - ) - with open(filename, "r") as f: - config = json.load(f) - config = PretrainedConfig(**config) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - config.quantize = quantize - model = MPTForCausalLM(config, weights) - - 
torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return MPTCausalLMBatch diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py deleted file mode 100644 index 6d7d07f5..00000000 --- a/server/text_generation_server/models/opt.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM -from text_generation_server.models import CausalLM -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class OPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - tokenizer.pad_token_id = config.pad_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index a167e467..3994ac70 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch): else: image_inputs = None return batch_tokenized_inputs, image_inputs - - -class PaliGemma(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - 
model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - - super().__init__( - config_cls=AutoConfig, - model_cls=PaliGemmaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self): - return PaliGemmaBatch - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py deleted file mode 100644 index 93d42b2b..00000000 --- a/server/text_generation_server/models/phi.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.distributed - -from transformers import AutoConfig, AutoTokenizer -from typing import Optional, List, Tuple - -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.phi_modeling import ( - PhiConfig, - PhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class Phi(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, _rank, _world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = PhiConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - model = PhiForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py deleted file mode 100644 index 37ca277b..00000000 --- a/server/text_generation_server/models/rw.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch - -from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import List, Optional, Tuple - -from text_generation_server.models import CausalLM - - -class RW(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if speculator: - raise 
RuntimeError("Medusa decoding is not enabled for AutoModel") - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py deleted file mode 100644 index caddbe19..00000000 --- a/server/text_generation_server/models/santacoder.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional, List -from transformers import AutoTokenizer, AutoModelForCausalLM - -from text_generation_server.models import CausalLM - -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_SUFFIX = "" -FIM_PAD = "" -EOD = "<|endoftext|>" - - -class SantaCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.add_special_tokens( - { - "additional_special_tokens": [ - EOD, - FIM_PREFIX, - FIM_MIDDLE, - FIM_SUFFIX, - FIM_PAD, - ], - "pad_token": EOD, - } - ) - with device: - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - 
dtype=dtype, - device=device, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index d454d804..fa8b5025 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,12 +1,24 @@ import torch +import torch.distributed import time from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoModelForSeq2SeqLM, + PreTrainedTokenizerBase, + AutoConfig, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -531,6 +543,84 @@ class Seq2SeqLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + config_class=AutoConfig, + tokenizer_class=AutoTokenizer, + aliases=None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. 
+ dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.bos_token_id = config.decoder_start_token_id + + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -574,7 +664,11 @@ class Seq2SeqLM(Model): ) tokenizer.bos_token_id = model.config.decoder_start_token_id - super(Seq2SeqLM, self).__init__( + self = cls.__new__( + cls, + ) + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -582,16 +676,12 @@ class Seq2SeqLM(Model): dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[Seq2SeqLMBatch]: return Seq2SeqLMBatch - def decode(self, decoder_ids: List[int]) -> str: - return self.tokenizer.decode( - decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - def forward( self, input_ids, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py deleted file mode 100644 index adef664c..00000000 --- a/server/text_generation_server/models/t5.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.distributed - -from typing import List, Optional, Tuple - -from transformers import ( - AutoTokenizer, - AutoConfig, -) - -from text_generation_server.models import Seq2SeqLM -from text_generation_server.models.custom_modeling.t5_modeling import ( - T5ForConditionalGeneration, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class T5Sharded(Seq2SeqLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - 
padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.bos_token_id = config.decoder_start_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={ - "shared.weight": [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - }, - ) - - model = T5ForConditionalGeneration(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(Seq2SeqLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask: Optional, - encoder_last_hidden_state: Optional, - past_key_values: Optional = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - ]: - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_last_hidden_state, - past_key_values=past_key_values, - use_cache=True, - ) - - return ( - outputs.logits, - speculative_logits, - outputs.encoder_last_hidden_state, - outputs.past_key_values, - ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1cdf37ea..ace48805 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, ) +from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch): return batch -class VlmCausalLM(BaseFlashMistral): +class VlmCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=VlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__(model_id=model_id, **kwargs) + @property def batch_type(self) -> Type[VlmCausalLMBatch]: - return VlmCausalLMBatch + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) def forward( self, diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py new file mode 100644 index 00000000..07975bea --- /dev/null +++ b/server/text_generation_server/utils/quantization.py @@ -0,0 +1,119 @@ +from typing import 
Optional +import os +import json +from dataclasses import dataclass + +from huggingface_hub import hf_hub_download + +from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader + + +@dataclass +class _QuantizerConfig: + bits: int + checkpoint_format: Optional[str] + desc_act: bool + groupsize: int + quant_method: str + sym: bool + + +# We should probably do this with Pydantic JSON deserialization, +# but for now we'll stay close to the old _set_gptq_params. +def _get_quantizer_config(model_id, revision): + bits = 4 + groupsize = -1 + quant_method = "gptq" + checkpoint_format = None + sym = True + desc_act = False + + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + data = json.load(f) + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] + # Order is important here, desc_act is missing on some real models + quant_method = data["quantization_config"]["quant_method"] + checkpoint_format = data["quantization_config"].get("checkpoint_format") + sym = data["quantization_config"]["sym"] + desc_act = data["quantization_config"]["desc_act"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["bits"] + groupsize = data["group_size"] + sym = data["sym"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + filename = "quant_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["w_bit"] + groupsize = data["q_group_size"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + pass + + return _QuantizerConfig( + bits=bits, + groupsize=groupsize, + quant_method=quant_method, + checkpoint_format=checkpoint_format, + sym=sym, + desc_act=desc_act, + ) + + +def get_loader( + quantize: Optional[str], model_id: str, revision: Optional[str] +) -> WeightsLoader: + quantizer_config = _get_quantizer_config(model_id, revision) + if quantize in {"awq", "gptq"}: + from text_generation_server.layers.gptq import GPTQWeightsLoader + + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + elif quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2WeightsLoader + + return Exl2WeightsLoader() + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeightsLoader + + return MarlinWeightsLoader( + bits=quantizer_config.bits, + is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", + ) + else: + return DefaultWeightsLoader() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 3731fd24..50a9167a 100644 ---
a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,13 +1,88 @@ -import os +from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional, Union -from safetensors import safe_open, SafetensorError +from safetensors import safe_open import torch -from loguru import logger -from huggingface_hub import hf_hub_download -import json -from text_generation_server.layers.gptq import GPTQParams -from text_generation_server.utils.log import log_once + + +class WeightsLoader(ABC): + """ + Instances of this type implement higher-level weight loading. + + At a low level, every weight is stored in the Safetensors format. + The interpretation of weights may differ, however; for instance, they + could be packed or quantized weights. Loaders are responsible for + interpreting the raw tensors, sharding tensors in a manner compatible + with the format, etc. + """ + + @abstractmethod + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + """ + Get the packed weights at the given prefix with column-splitting for + tensor parallelism. This method should be used when multiple different + weights are packed into a tensor, for instance, query/key/value + weights or a gate/up projection. + + The `block_sizes` argument determines the proportions of the packed tensors. + The columns are split in equally sized blocks when `block_sizes` is an + `int`, or in blocks proportional to the given sizes. For instance + `[2, 1, 1]` will divide an input with dimensionality `1024` into + `[512, 256, 256]`. + """ + ... + + def get_weights_col(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply column-splitting for tensor + parallelism. + """ + return weights.get_multi_weights_col([prefix], 0) + + @abstractmethod + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + """ + Get the weights at the given prefixes, column-split them for tensor + parallelism, and then concatenate the weights along the given dimension. + """ + ... + + @abstractmethod + def get_weights_row(self, weights: "Weights", prefix: str): + """ + Get the weights at the given prefix and apply row-splitting for tensor + parallelism. + """ + ... + + +class DefaultWeightsLoader(WeightsLoader): + """ + Loader that uses tensors as-is, except for applying sharding + and/or concatenation.
+ """ + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + return weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + return torch.cat(w, dim=dim) + + def get_weights_row(self, weights: "Weights", prefix: str): + return weights.get_sharded(f"{prefix}.weight", dim=1) class Weights: @@ -17,6 +92,7 @@ class Weights: device, dtype, process_group, + weights_loader: WeightsLoader, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, ): @@ -37,6 +113,7 @@ class Weights: self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = weights_loader self._handles = {} def _get_handle(self, filename): @@ -69,6 +146,13 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() @@ -181,295 +265,27 @@ class Weights: num_key_value_heads: int, ): return self.get_weights_col_packed( - prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + prefix, [num_heads, num_key_value_heads, num_key_value_heads] ) - def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 2) + def get_weights_col_packed_gate_up(self, prefix: str): + return self.get_weights_col_packed(prefix, 2) - def get_weights_col_packed( - self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] - ): + def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]): """ - Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor. - The columns are split in equally sized blocks when blocks is an `int`, or in blocks proportional given to the sizes. For instance `[2, 1, 1]` will divide an input with dimensionality `1024` in `[512, 256, 256]`. This is convenient for e.g. splitting QKV without knowing the storage details of quantized weights. """ - if quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) + return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes) - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized." 
- ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) + def get_weights_col(self, prefix: str): + return self.weights_loader.get_weights_col(self, prefix) - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - g_idx = self.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = self.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - if quantize == "gptq" and gptq_params.quant_method == "gptq": - g_idx = self.get_tensor(f"{prefix}.g_idx") - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=False, - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - repack_gptq_for_marlin, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - B = self.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = self.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - B = self.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - return weight - - def get_weights_col(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." 
- ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - return self.get_multi_weights_col([prefix], quantize, 0) - - def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize == "exl2": - raise ValueError("get_multi_weights_col is not supported for exl2") - elif quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - gptq_params.bits == 4 - and HAS_EXLLAMA - and quantize == "gptq" - and not gptq_params.desc_act - ) - - if quantize == "gptq" and gptq_params.quant_method == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
- ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - else: - w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=dim) - - return weight + def get_multi_weights_col(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights_col(self, prefixes, dim) def get_tensor_shard(self, var, dim): world_size = self.process_group.size() @@ -487,318 +303,8 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." 
- ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - elif quantize == "gptq": - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) - - use_exllama = True - if gptq_params.bits != 4: - use_exllama = False - - if gptq_params.desc_act: - log_once(logger.warning, "Disabling exllama because desc_act=True") - use_exllama = False - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - - if gptq_params.quant_method == "gptq": - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - elif gptq_params.quant_method == "awq": - g_idx = None - - if self.process_group.size() > 1: - if g_idx is not None: - if ( - not torch.equal( - g_idx.cpu(), - torch.tensor( - [ - i // gptq_params.groupsize - for i in range(g_idx.shape[0]) - ], - dtype=torch.int32, - ), - ) - and not (g_idx == 0).all() - ): - # Exllama implementation does not support row tensor parallelism with act-order, as - # it would require to reorder input activations that are split unto several GPUs - use_exllama = False - - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and gptq_params.groupsize != -1: - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - - if gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
- ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `awq` weight, make sure the model is already quantized" - ) - - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - g_idx = None - use_exllama = False - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = self.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = self.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. 
- s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_sharded(f"{prefix}.weight", dim=1) - return weight - - def _get_gptq_params(self) -> GPTQParams: - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = False - sym = False - quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - try: - bits = self.gptq_bits - groupsize = self.gptq_groupsize - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = getattr(self, "gptq_desc_act", False) - quant_method = getattr(self, "quant_method", "gptq") - sym = getattr(self, "sym", True) - except Exception: - raise e - - return GPTQParams( - bits=bits, - checkpoint_format=checkpoint_format, - desc_act=desc_act, - groupsize=groupsize, - quant_method=quant_method, - sym=sym, - ) - - def _set_gptq_params(self, model_id, revision): - filename = "config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["quantization_config"]["bits"] - self.gptq_groupsize = data["quantization_config"]["group_size"] - # Order is important here, desc_act is missing on some real models - self.quant_method = data["quantization_config"]["quant_method"] - self.gptq_checkpoint_format = data["quantization_config"].get( - "checkpoint_format" - ) - self.gptq_sym = data["quantization_config"]["sym"] - self.gptq_desc_act = data["quantization_config"]["desc_act"] - except Exception: - filename = "quantize_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["bits"] - self.gptq_groupsize = data["group_size"] - self.gptq_sym = data["sym"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - filename = "quant_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["w_bit"] - self.gptq_groupsize = data["q_group_size"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - pass + def get_weights_row(self, prefix: str): + return self.weights_loader.get_weights_row(self, prefix) def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: diff --git a/update_doc.py b/update_doc.py index 1ff94a2c..bfa7e4e9 100644 --- a/update_doc.py +++ b/update_doc.py @@ -155,7 +155,7 @@ def check_openapi(check: bool): filename, ], capture_output=True, - ).stdout.decode() + ).stdout.decode("utf-8") os.remove(tmp_filename) if diff: @@ -164,11 +164,27 @@ def check_openapi(check: bool): "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" ) - return True else: os.rename(tmp_filename, filename) print("OpenAPI 
documentation updated.") - return True + errors = subprocess.run( + [ + "swagger-cli", + # allow for trailing whitespace since it's not significant + # and the precommit hook will remove it + "validate", + filename, + ], + capture_output=True, + ).stderr.decode("utf-8") + # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where + # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 + if not errors.startswith("Swagger schema validation failed."): + print(errors) + raise Exception( + f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + ) + return True def main():