Merge branch 'main' into feature/usage-stats

ErikKaumk 2024-07-15 10:08:10 +02:00
commit 81c9ad7073
109 changed files with 4328 additions and 3761 deletions


@@ -30,6 +30,10 @@ jobs:
        id: install-router
        run: cargo install --path router/
+      - uses: actions/setup-node@v4
+        with:
+          node-version: 22
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
@@ -37,4 +41,5 @@ jobs:
      - name: Check that documentation is up-to-date
        run: |
+          npm install -g swagger-cli
          python update_doc.py --check


@@ -11,6 +11,11 @@ on:
        # - rocm
        # - intel
        required: true
+      release-tests:
+        description: "Run release integration tests"
+        required: true
+        default: false
+        type: boolean
jobs:
  build-and-push:
@@ -23,7 +28,7 @@ jobs:
      group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    # TODO see with @Glegendre to get CPU runner here instead
-    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
+    runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
    permissions:
      contents: write
      packages: write
@@ -131,8 +136,8 @@ jobs:
          DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-          cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
-          cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
+          cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+          cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
      - name: Final
        id: final
        run: |
@@ -148,7 +153,7 @@ jobs:
    runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
    if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
    env:
-      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }}
+      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4


@@ -20,7 +20,14 @@ on:
      - "Dockerfile_amd"
      - "Dockerfile_intel"
    branches:
-      - 'main'
+      - "main"
+  workflow_dispatch:
+    inputs:
+      release-tests:
+        description: "Run release integration tests"
+        required: true
+        default: false
+        type: boolean
jobs:
  build:
@@ -33,4 +40,6 @@ jobs:
    uses: ./.github/workflows/build.yaml # calls the one above ^
    with:
      hardware: ${{ matrix.hardware }}
+      # https://github.com/actions/runner/issues/2206
+      release-tests: ${{ inputs.release-tests == true }}
    secrets: inherit

Cargo.lock (generated): 28 lines changed

@@ -1935,17 +1935,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
-[[package]]
-name = "metrics"
-version = "0.21.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
-dependencies = [
- "ahash",
- "metrics-macros",
- "portable-atomic",
-]
[[package]]
name = "metrics"
version = "0.23.0"
@@ -1969,7 +1958,7 @@ dependencies = [
"hyper-util",
"indexmap 2.2.6",
"ipnet",
-"metrics 0.23.0",
+"metrics",
"metrics-util",
"quanta",
"thiserror",
@@ -1977,17 +1966,6 @@ dependencies = [
"tracing",
]
-[[package]]
-name = "metrics-macros"
-version = "0.7.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.68",
-]
[[package]]
name = "metrics-util"
version = "0.17.0"
@@ -1997,7 +1975,7 @@ dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"hashbrown 0.14.5",
-"metrics 0.23.0",
+"metrics",
"num_cpus",
"quanta",
"sketches-ddsketch",
@@ -3835,7 +3813,7 @@ dependencies = [
"init-tracing-opentelemetry",
"itertools 0.10.5",
"jsonschema",
-"metrics 0.21.1",
+"metrics",
"metrics-exporter-prometheus",
"minijinja",
"minijinja-contrib",


@@ -40,7 +40,9 @@ RUN cargo build --profile release-opt
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
+# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
@@ -241,7 +243,10 @@ COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
+    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
+    pip install nvidia-nccl-cu12==2.22.3
+ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
# Deps before the binaries
# The binaries change on every build given we burn the SHA into them


@@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
## Table of contents
- [Get Started](#get-started)
-- [API Documentation](#api-documentation)
+- [Docker](#docker)
+- [API documentation](#api-documentation)
- [Using a private or gated model](#using-a-private-or-gated-model)
-- [A note on Shared Memory](#a-note-on-shared-memory-shm)
+- [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
-- [Local Install](#local-install)
-- [CUDA Kernels](#cuda-kernels)
+- [Architecture](#architecture)
+- [Local install](#local-install)
- [Optimized architectures](#optimized-architectures)
-- [Run Mistral](#run-a-model)
+- [Run locally](#run-locally)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:


@@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
    role: str
    content: Optional[str] = None
-    tool_calls: Optional[ChoiceDeltaToolCall]
+    tool_calls: Optional[ChoiceDeltaToolCall] = None
class Choice(BaseModel):


@@ -11,6 +11,8 @@
    title: Using TGI with Intel Gaudi
  - local: installation_inferentia
    title: Using TGI with AWS Inferentia
+  - local: installation_intel
+    title: Using TGI with Intel GPUs
  - local: installation
    title: Installation from source
  - local: supported_models


@@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin
- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
+- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).


@@ -0,0 +1,19 @@
# Using TGI with Intel GPUs
TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html), [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html), the recommended usage is through Docker.
On a server powered by Intel GPUs, TGI can be launched with the following command:
```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:latest-intel \
--model-id $model --cuda-graphs 0
```
The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
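For reference, a minimal sketch of such a query against the server launched above (assuming it is reachable on the host at port 80 because the container runs with `--net host`; the prompt and parameters are only illustrative):

```bash
# Send a generate request to the TGI server started with the Docker command above.
curl localhost:80/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}'
```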


@@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
### Supported hardware
-TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
+TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
## Consuming TGI


@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
+- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/google/gemma2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)


@ -5,85 +5,80 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.7890625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.625, "logprob": -11.34375,
"text": "request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3359375, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8779297, "logprob": -1.6230469,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2744141, "logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6933594, "logprob": -0.076660156,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4648438, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15600586, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.8027344, "logprob": -0.10821533,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.23022461,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0069885254,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.02218628,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
} }


@ -5,85 +5,80 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.6015625, "logprob": -11.34375,
"text": "request" "text": " request"
} }
], ],
"seed": 0, "seed": 0,
"tokens": [ "tokens": [
{ {
"id": 29899, "id": 13,
"logprob": -1.5625, "logprob": -2.2539062,
"special": false, "special": false,
"text": "-" "text": "."
}, },
{ {
"id": 1454, "id": 578,
"logprob": -0.20410156, "logprob": -0.15563965,
"special": false, "special": false,
"text": "for" "text": " The"
}, },
{ {
"id": 29899, "id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "-" "text": " has"
}, },
{ {
"id": 9342, "id": 539,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "comment" "text": " not"
}, },
{ {
"id": 29901, "id": 3686,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": ":" "text": " yet"
}, },
{ {
"id": 396, "id": 3288,
"logprob": -0.27685547,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.4970703,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.80615234,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "1" "text": " sent"
}, },
{ {
"id": 29955, "id": 904,
"logprob": -1.0751953, "logprob": 0.0,
"special": false, "special": false,
"text": "7" "text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request-for-comment: #2017" "generated_text": "Test request. The server has not yet sent any data.\n\n"
} }


@ -6,87 +6,82 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.828125,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.609375, "logprob": -11.34375,
"text": "request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3300781, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8740234, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2646484, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.7158203, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4667969, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15344238, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.81591797, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22973633,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021957397,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
}, },
{ {
"details": { "details": {
@ -95,87 +90,82 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.59375, "logprob": -11.34375,
"text": "request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3378906, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8779297, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2636719, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6992188, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4589844, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15344238, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.79052734, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22937012,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007041931,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022140503,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
}, },
{ {
"details": { "details": {
@ -184,87 +174,82 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.609375, "logprob": -11.34375,
"text": "request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3261719, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8730469, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2587891, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6894531, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.46875, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.1541748, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.80322266, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22912598,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070495605,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021606445,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
}, },
{ {
"details": { "details": {
@ -273,86 +258,81 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.6015625, "logprob": -11.34375,
"text": "request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3320312, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.875, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2646484, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6884766, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4589844, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15185547, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.79833984, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22827148,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006996155,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021560669,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
} }
] ]


@ -1,130 +1,124 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 20, "generated_tokens": 19,
"prefill": [], "prefill": [],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 415, "id": 415,
"logprob": -0.039886475, "logprob": -0.03665161,
"special": false, "special": false,
"text": " The" "text": " The"
}, },
{ {
"id": 12072, "id": 12072,
"logprob": -0.1430664, "logprob": -0.13549805,
"special": false, "special": false,
"text": " cow" "text": " cow"
}, },
{ {
"id": 349, "id": 349,
"logprob": -0.056488037, "logprob": -0.05819702,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 6328, "id": 6328,
"logprob": -0.6855469, "logprob": -0.6826172,
"special": false, "special": false,
"text": " standing" "text": " standing"
}, },
{ {
"id": 356, "id": 356,
"logprob": -0.1685791, "logprob": -0.1607666,
"special": false, "special": false,
"text": " on" "text": " on"
}, },
{ {
"id": 272, "id": 272,
"logprob": -0.50097656, "logprob": -0.5073242,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 10305, "id": 10305,
"logprob": -0.017303467, "logprob": -0.016418457,
"special": false, "special": false,
"text": " beach" "text": " beach"
}, },
{ {
"id": 304, "id": 304,
"logprob": -1.3564453, "logprob": -1.3916016,
"special": false, "special": false,
"text": " and" "text": " and"
}, },
{ {
"id": 272, "id": 272,
"logprob": -0.017868042, "logprob": -0.020217896,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 13088, "id": 13088,
"logprob": -0.0027103424, "logprob": -0.0028133392,
"special": false, "special": false,
"text": " chicken" "text": " chicken"
}, },
{ {
"id": 349, "id": 349,
"logprob": -0.003156662, "logprob": -0.003145218,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 6398, "id": 6398,
"logprob": -0.37304688, "logprob": -0.37060547,
"special": false, "special": false,
"text": " sitting" "text": " sitting"
}, },
{ {
"id": 356, "id": 356,
"logprob": -0.034576416, "logprob": -0.034851074,
"special": false, "special": false,
"text": " on" "text": " on"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.29418945, "logprob": -0.2878418,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 17972, "id": 17972,
"logprob": -0.042877197, "logprob": -0.046051025,
"special": false, "special": false,
"text": " pile" "text": " pile"
}, },
{ {
"id": 302, "id": 302,
"logprob": -0.00028443336, "logprob": -0.00028848648,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 2445, "id": 2445,
"logprob": -0.023223877, "logprob": -0.025772095,
"special": false, "special": false,
"text": " money" "text": " money"
}, },
{ {
"id": 28723, "id": 28723,
"logprob": -0.018157959, "logprob": -0.018127441,
"special": false, "special": false,
"text": "." "text": "."
}, },
{ {
"id": 32002, "id": 32002,
"logprob": -0.00018393993, "logprob": -0.00019824505,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
} }
], ],
"top_tokens": null "top_tokens": null


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.8359375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6171875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3417969,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8730469,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2626953,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7060547,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4482422,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15246582,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.796875,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22766113,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021759033,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
}


@@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.7890625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 29899,
"logprob": -1.4980469,
"special": false,
"text": "-"
},
{
"id": 1454,
"logprob": -0.19433594,
"special": false,
"text": "for"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 9342,
"logprob": 0.0,
"special": false,
"text": "comment"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 396,
"logprob": -0.27392578,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.49389648,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.81103516,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 29955,
"logprob": -1.0800781,
"special": false,
"text": "7"
}
],
"top_tokens": null
},
"generated_text": "Test request-for-comment: #2017"
}


@@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.8828125,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.5859375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3359375,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8623047,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2451172,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6923828,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4492188,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15197754,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.8022461,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22583008,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007095337,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021652222,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.796875,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3476562,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8789062,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2734375,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.703125,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4677734,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15454102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.7973633,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.23278809,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006980896,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022033691,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.9296875,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.5703125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3203125,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8486328,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2480469,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7060547,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4511719,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.1529541,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.81396484,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22180176,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007133484,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021835327,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6171875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3261719,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8691406,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2597656,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7070312,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4550781,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.1538086,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.79345703,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22924805,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070266724,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021942139,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
}
]


@@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
-    with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
+    with launcher(
+        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq"
+    ) as handle:
        yield handle


@@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
        response.generated_text
        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
    ), f"{repr(response.generated_text)}"
-    assert response.details.generated_tokens == 20
+    assert response.details.generated_tokens == 19
    assert response == response_snapshot


@@ -24,7 +24,7 @@ futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
-metrics = "0.21.1"
+metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }


@@ -91,14 +91,14 @@ impl Infer {
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
-metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
+metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1);
tracing::error!("{err}");
err
})?;
// Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| {
-metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
err
})?;
@@ -140,7 +140,7 @@ impl Infer {
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt)
.map_err(|e| {
-metrics::increment_counter!("tgi_request_failure", "err" => "template");
+metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}");
e
})
@@ -214,7 +214,7 @@ impl Infer {
})
} else {
let err = InferError::IncompleteGeneration;
-metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
+metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
Err(err)
}


@@ -111,7 +111,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
-metrics::increment_gauge!("tgi_queue_size", 1.0);
+metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@@ -124,7 +124,7 @@ async fn queue_task(
let next_batch =
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
-metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
+metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}),
}
}
@@ -226,7 +226,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
-metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
@@ -336,7 +336,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
-metrics::histogram!("tgi_batch_next_size", batch.size as f64);
+metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}


@@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
-metrics::gauge!("tgi_batch_current_size", batch_size as f64);
-metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
+metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
-metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+    .increment(1);
} else {
-metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+    .increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
@@ -219,8 +221,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
-metrics::gauge!("tgi_batch_current_size", 0.0);
-metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
+metrics::gauge!("tgi_batch_current_size").set(0.0);
+metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@@ -234,7 +236,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
-metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
+metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@@ -248,11 +250,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
-metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
-metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
+    .record(timings.forward.as_secs_f64());
+metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+    .record(timings.decode.as_secs_f64());
+metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+    .record(start_filtering_time.elapsed().as_secs_f64());
+metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
+    .record(start_time.elapsed().as_secs_f64());
+metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@@ -261,7 +267,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
-metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@@ -276,7 +282,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
+metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@@ -291,13 +297,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
-metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+    .record(concat_duration.as_secs_f64());
}
-metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
-metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
+metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+    .record(timings.forward.as_secs_f64());
+metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+    .record(timings.decode.as_secs_f64());
+metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+    .record(start_filtering_time.elapsed().as_secs_f64());
+metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+    .record(start_time.elapsed().as_secs_f64());
+metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@@ -307,7 +318,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
-metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
+metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
@@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
-metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
@@ -381,7 +392,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
-metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
@@ -407,7 +418,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
-metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
+metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
-metrics::increment_counter!("tgi_request_failure", "err" => "generation");
+metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.


@@ -126,7 +126,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
-metrics::increment_gauge!("tgi_queue_size", 1.0);
+metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@@ -141,7 +141,7 @@ async fn queue_task(
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
-metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
+metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}
}
}
@@ -248,7 +248,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
-metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
@@ -399,7 +399,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
-metrics::histogram!("tgi_batch_next_size", batch.size as f64);
+metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}


@@ -154,8 +154,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
-metrics::gauge!("tgi_batch_current_size", batch_size as f64);
-metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
+metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@@ -176,9 +176,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
-metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+    .increment(1);
} else {
-metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+    .increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
@@ -225,8 +227,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
-metrics::gauge!("tgi_batch_current_size", 0.0);
-metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
+metrics::gauge!("tgi_batch_current_size").set(0.0);
+metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@@ -240,7 +242,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
-metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
+metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@@ -254,11 +256,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
-metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
-metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
-metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
+    .record(timings.forward.as_secs_f64());
+metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+    .record(timings.decode.as_secs_f64());
+metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+    .record(start_filtering_time.elapsed().as_secs_f64());
+metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
+    .record(start_time.elapsed().as_secs_f64());
+metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@@ -267,7 +273,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
-metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@@ -282,7 +288,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
+metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@@ -297,13 +303,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
-metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+    .record(concat_duration.as_secs_f64());
}
-metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
-metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
-metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
+metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+    .record(timings.forward.as_secs_f64());
+metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+    .record(timings.decode.as_secs_f64());
+metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
@ -313,7 +324,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await; let _ = client.clear_cache(Some(id)).await;
} }
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None None
} }
} }
@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true) // request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| { let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error."); tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err err
}).unwrap_or(true); }).unwrap_or(true);
if stopped { if stopped {
@ -387,7 +398,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> { ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected // Return directly if the channel is disconnected
if entry.response_tx.is_closed() { if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true); return Ok(true);
} }
@ -413,7 +424,7 @@ fn send_responses(
// Create last Token // Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len(); let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_ let mut iterator = tokens_
.ids .ids
.into_iter() .into_iter()
@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry // Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string()); let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation"); metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.
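A detail worth calling out in the hunks above: every send path first checks `entry.response_tx.is_closed()` and, when the client has gone away, increments `tgi_request_failure{err="dropped"}` and stops instead of generating further tokens for a dead connection. A small sketch of that pattern with a Tokio channel; the `Entry` struct here is a stand-in invented for illustration, not the router's:

```rust
use tokio::sync::mpsc;

// Hypothetical stand-in for the router's per-request entry.
struct Entry {
    response_tx: mpsc::UnboundedSender<String>,
}

fn try_send(entry: &Entry, msg: String) -> bool {
    // Receiver dropped == client disconnected: count it and bail out early.
    if entry.response_tx.is_closed() {
        metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
        return false;
    }
    entry.response_tx.send(msg).is_ok()
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel::<String>();
    let entry = Entry { response_tx: tx };
    drop(rx); // simulate the client dropping the connection
    assert!(!try_send(&entry, "token".into()));
}
```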

View File

@ -386,7 +386,7 @@ pub struct CompletionRequest {
/// UNUSED /// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, pub model: Option<String>,
/// The prompt to generate completions for. /// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
@ -733,7 +733,7 @@ impl ChatCompletionChunk {
pub(crate) struct ChatRequest { pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String, pub model: Option<String>,
/// A list of messages comprising the conversation so far. /// A list of messages comprising the conversation so far.
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
@ -850,7 +850,7 @@ pub enum ToolType {
Function { function: FunctionName }, Function { function: FunctionName },
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct FunctionName { pub struct FunctionName {
pub name: String, pub name: String,
} }
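The `ToSchema` derive added to `FunctionName` here pairs with the `ApiDoc` change further down, where `FunctionName`, `MessageChunk`, `OutputMessage`, `ToolCallDelta` and friends are registered in the OpenAPI component list. With utoipa both halves are needed: the derive generates a schema for the type, and listing it under `components(schemas(...))` makes it show up in the generated spec. A stripped-down sketch; the `Demo*` names are placeholders, not the router's types:

```rust
use utoipa::{OpenApi, ToSchema};

// The derive generates an OpenAPI schema for the type...
#[derive(ToSchema)]
#[allow(dead_code)]
struct DemoFunctionName {
    name: String,
}

// ...and registering it under `components(schemas(...))` pulls that schema
// into the generated document.
#[derive(OpenApi)]
#[openapi(components(schemas(DemoFunctionName)))]
struct ApiDoc;

fn main() {
    // Render the spec to confirm the component is present.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}
```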

View File

@ -11,10 +11,11 @@ use crate::kserve::{
}; };
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig,
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken,
Usage, Validation, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse,
ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
}; };
use crate::{ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@ -185,7 +186,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span, span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::counter!("tgi_request_count").increment(1);
// Do not log ultra long inputs, like image payloads. // Do not log ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]); tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
@ -301,25 +302,15 @@ pub(crate) async fn generate_internal(
); );
// Metrics // Metrics
metrics::increment_counter!("tgi_request_success"); metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!( metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
"tgi_request_validation_duration", metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
validation_time.as_secs_f64() metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
); metrics::histogram!("tgi_request_mean_time_per_token_duration")
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); .record(time_per_token.as_secs_f64());
metrics::histogram!( metrics::histogram!("tgi_request_generated_tokens")
"tgi_request_inference_duration", .record(response.generated_text.generated_tokens as f64);
inference_time.as_secs_f64()
);
metrics::histogram!(
"tgi_request_mean_time_per_token_duration",
time_per_token.as_secs_f64()
);
metrics::histogram!(
"tgi_request_generated_tokens",
response.generated_text.generated_tokens as f64
);
// Send response // Send response
let mut output_text = response.generated_text.text; let mut output_text = response.generated_text.text;
@ -399,7 +390,7 @@ async fn generate_stream_internal(
span: tracing::Span, span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) { ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::counter!("tgi_request_count").increment(1);
tracing::debug!("Input: {}", req.inputs); tracing::debug!("Input: {}", req.inputs);
@ -427,12 +418,12 @@ async fn generate_stream_internal(
let best_of = req.parameters.best_of.unwrap_or(1); let best_of = req.parameters.best_of.unwrap_or(1);
if best_of != 1 { if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream); let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} else if req.parameters.decoder_input_details { } else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream); let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} else { } else {
@ -500,13 +491,13 @@ async fn generate_stream_internal(
span.record("seed", format!("{:?}", generated_text.seed)); span.record("seed", format!("{:?}", generated_text.seed));
// Metrics // Metrics
metrics::increment_counter!("tgi_request_success"); metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);
// StreamResponse // StreamResponse
end_reached = true; end_reached = true;
@ -553,7 +544,7 @@ async fn generate_stream_internal(
// Skip if we already sent an error // Skip if we already sent an error
if !end_reached && !error { if !end_reached && !error {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} }
@ -572,8 +563,8 @@ request_body = CompletionRequest,
responses( responses(
(status = 200, description = "Generated Chat Completion", (status = 200, description = "Generated Chat Completion",
content( content(
("application/json" = Completion), ("application/json" = CompletionFinal),
("text/event-stream" = CompletionCompleteChunk), ("text/event-stream" = Chunk),
)), )),
(status = 424, description = "Generation Error", body = ErrorResponse, (status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})), example = json ! ({"error": "Request failed during generation"})),
@ -604,9 +595,10 @@ async fn completions(
Json(req): Json<CompletionRequest>, Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::counter!("tgi_request_count").increment(1);
let CompletionRequest { let CompletionRequest {
model,
max_tokens, max_tokens,
seed, seed,
stop, stop,
@ -625,7 +617,7 @@ async fn completions(
// if suffix is present throw an error // if suffix is present throw an error
if req.suffix.is_some() { if req.suffix.is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err(( return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse { Json(ErrorResponse {
@ -637,7 +629,7 @@ async fn completions(
} }
if req.prompt.0.len() > info.max_client_batch_size { if req.prompt.0.len() > info.max_client_batch_size {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err(( return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse { Json(ErrorResponse {
@ -675,7 +667,7 @@ async fn completions(
seed, seed,
top_n_tokens: None, top_n_tokens: None,
grammar: None, grammar: None,
..Default::default() adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
}, },
}) })
.collect(); .collect();
@ -820,6 +812,10 @@ async fn completions(
} }
}; };
let stream = stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));
let sse = Sse::new(stream).keep_alive(KeepAlive::default()); let sse = Sse::new(stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
@ -1009,8 +1005,9 @@ async fn chat_completions(
Json(req): Json<ChatRequest>, Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::counter!("tgi_request_count").increment(1);
let ChatRequest { let ChatRequest {
model,
logprobs, logprobs,
max_tokens, max_tokens,
messages, messages,
@ -1039,7 +1036,7 @@ async fn chat_completions(
// response_format and tools are mutually exclusive // response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() { if response_format.is_some() && tools.as_ref().is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err(( return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse { Json(ErrorResponse {
@ -1053,7 +1050,7 @@ async fn chat_completions(
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar, Ok(grammar) => grammar,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
return Err(( return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
@ -1082,7 +1079,7 @@ async fn chat_completions(
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
return Err(( return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
@ -1116,7 +1113,7 @@ async fn chat_completions(
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar, grammar,
..Default::default() adapter_id: model.filter(|m| *m != "tgi").map(String::from),
}, },
}; };
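Turning `model` into an `Option<String>` in the request types above is what lets the handlers reuse the field: a model name other than the placeholder `"tgi"` is forwarded as `adapter_id`, while omitting the field (or sending `"tgi"`) leaves `adapter_id` unset and routes to the base model. A small sketch of that mapping in isolation; the function name and the example adapter value are illustrative:

```rust
fn adapter_id_from_model(model: Option<&str>) -> Option<String> {
    // "tgi" is treated as "the base model": no adapter is selected.
    model.filter(|m| *m != "tgi").map(String::from)
}

fn main() {
    assert_eq!(adapter_id_from_model(None), None);
    assert_eq!(adapter_id_from_model(Some("tgi")), None);
    assert_eq!(
        adapter_id_from_model(Some("my-org/my-lora")),
        Some("my-org/my-lora".to_string())
    );
}
```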
@ -1178,6 +1175,11 @@ async fn chat_completions(
span, span,
) )
.await; .await;
let response_stream = response_stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
@ -1280,7 +1282,7 @@ async fn vertex_compatibility(
Json(req): Json<VertexRequest>, Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::counter!("tgi_request_count").increment(1);
// check that there's at least one instance // check that there's at least one instance
if req.instances.is_empty() { if req.instances.is_empty() {
@ -1454,6 +1456,14 @@ pub async fn run(
GrammarType, GrammarType,
ChatRequest, ChatRequest,
Message, Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete, ChatCompletionComplete,
ChatCompletionChoice, ChatCompletionChoice,
ChatCompletionDelta, ChatCompletionDelta,
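Both streaming handlers above now append a terminal `[DONE]` event once generation finishes, which is what OpenAI-style streaming clients look for before closing the connection. The mechanism is simply `StreamExt::chain` with a one-item stream; the self-contained sketch below uses plain strings instead of axum's `Event` so it runs on its own:

```rust
use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-in for the SSE payloads produced while generating.
    let tokens = stream::iter(vec![
        r#"data: {"token":"Hello"}"#,
        r#"data: {"token":" world"}"#,
    ]);

    // Chain a single terminal event onto the end of the stream, mirroring
    // how the handlers chain `Event::default().data("[DONE]")`.
    let with_done = tokens.chain(stream::once(async { "data: [DONE]" }));

    let events: Vec<_> = with_done.collect().await;
    assert_eq!(events.last(), Some(&"data: [DONE]"));
}
```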

View File

@ -157,7 +157,7 @@ impl Validation {
)); ));
} }
metrics::histogram!("tgi_request_input_length", input_length as f64); metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens)) Ok((inputs, input_length, max_new_tokens))
} }
// Return inputs without validation // Return inputs without validation
@ -384,7 +384,7 @@ impl Validation {
ignore_eos_token: false, ignore_eos_token: false,
}; };
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs, inputs,

View File

@ -35,5 +35,5 @@ run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
export-requirements: export-requirements:
poetry export -o requirements_cuda.txt --without-hashes poetry export -o requirements_cuda.txt --without-hashes --with cuda
poetry export -o requirements_rocm.txt --without-hashes poetry export -o requirements_rocm.txt --without-hashes

View File

@ -59,3 +59,18 @@ def marlin_gemm(
Matrix multiplication using Marlin kernels. Matrix multiplication using Marlin kernels.
""" """
... ...
# fp8 marlin
def fp8_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
)

View File

@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_repack", &gptq_marlin_repack, m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin"); "Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
} }

View File

@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace, torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k); int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
#endif #endif

File diff suppressed because it is too large

View File

@ -9,6 +9,7 @@ setup(
CUDAExtension( CUDAExtension(
name="marlin_kernels", name="marlin_kernels",
sources=[ sources=[
"marlin_kernels/fp8_marlin.cu",
"marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu", "marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu", "marlin_kernels/marlin_cuda_kernel.cu",

View File

@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -16,7 +19,10 @@ def default_bloom():
revision = "main" revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors") filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision) download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id) return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_causal_lm(): def default_causal_lm():
return CausalLM("gpt2") return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -1,13 +1,12 @@
import pytest import pytest
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_santacoder(): def default_santacoder():
return SantaCoder("bigcode/santacoder") return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture @pytest.fixture

View File

@ -20,7 +20,7 @@ def mt0_small_tokenizer():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_seq2seq_lm(): def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small") return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture @pytest.fixture

View File

@ -2,6 +2,7 @@ import torch
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
) )
from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup: class ProcessGroup:
@ -42,7 +43,12 @@ class Weights:
def test_weight_hub_files_offline_error(): def test_weight_hub_files_offline_error():
vocab_size = 17 vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights) embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size) input_ids = torch.arange(vocab_size)

View File

@ -1,13 +1,47 @@
import pytest import pytest
import torch import torch
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import (
from text_generation_server.layers.gptq import GPTQWeight DefaultWeightsLoader,
from text_generation_server.layers.exl2 import Exl2Weight Weights,
from text_generation_server.layers.marlin import MarlinWeight WeightsLoader,
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
from types import SimpleNamespace from types import SimpleNamespace
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from pathlib import Path from pathlib import Path
@pytest.fixture
def gptq_weights_loader():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="gptq",
quantize="gptq",
sym=True,
)
@pytest.fixture
def gptq_weights_loader_awq():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="awq",
quantize="awq",
sym=True,
)
@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
dummy_file_system = { dummy_file_system = {
"test_weights": { "test_weights": {
"layer.0.weight": torch.tensor( "layer.0.weight": torch.tensor(
@ -58,7 +92,7 @@ dummy_file_system = {
dtype=torch.float32, dtype=torch.float32,
), ),
}, },
"test_get_multi_weights_row": { "test_get_weights_row": {
"weight.weight": torch.tensor( "weight.weight": torch.tensor(
[ [
[1, 2], [1, 2],
@ -101,7 +135,7 @@ dummy_file_system = {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
}, },
"test_get_multi_weights_row_gptq": { "test_get_weights_row_gptq": {
"weight.qweight": torch.tensor( "weight.qweight": torch.tensor(
[ [
[1, 2], [1, 2],
@ -200,7 +234,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16), "weight.q_groups": torch.tensor([4], dtype=torch.int16),
}, },
"test_get_multi_weights_row_exl2": { "test_get_weights_row_exl2": {
"weight.q_weight": torch.tensor( "weight.q_weight": torch.tensor(
[ [
[1, 2], [1, 2],
@ -245,7 +279,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16), "weight.q_groups": torch.tensor([4], dtype=torch.int16),
}, },
"test_get_multi_weights_row_marlin": { "test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
}, },
@ -308,6 +342,7 @@ class MockWeights(Weights):
dummy_fs, dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None, aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
weights_loader: Optional[WeightsLoader] = None,
): ):
routing = {} routing = {}
self.dummy_fs = dummy_fs self.dummy_fs = dummy_fs
@ -327,6 +362,9 @@ class MockWeights(Weights):
self.dtype = dtype self.dtype = dtype
self.process_group = process_group self.process_group = process_group
self.prefix = prefix self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
)
self._handles = {} self._handles = {}
def _get_handle(self, filename: Union[Path, str]): def _get_handle(self, filename: Union[Path, str]):
@ -412,12 +450,10 @@ def test_get_weights_col_packed():
) )
prefix = "weight" prefix = "weight"
quantize = None
block_sizes = 1 block_sizes = 1
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
) )
prefix = "weight" prefix = "weight"
quantize = None
block_sizes = 2 block_sizes = 2
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
) )
prefix = "weight" prefix = "weight"
quantize = None
block_sizes = [1, 1] block_sizes = [1, 1]
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
@ -519,11 +551,9 @@ def test_get_multi_weights_col():
) )
prefixes = ["weight", "weight"] prefixes = ["weight", "weight"]
quantize = None
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -545,10 +575,10 @@ def test_get_multi_weights_col():
) )
def test_get_multi_weights_row(): def test_get_weights_row():
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row", "test_get_weights_row",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
@ -557,11 +587,9 @@ def test_get_multi_weights_row():
) )
prefix = "weight" prefix = "weight"
quantize = None
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
assert torch.allclose( assert torch.allclose(
@ -576,7 +604,7 @@ def test_get_multi_weights_row():
# test_get_weights_col # test_get_weights_col
def test_get_weights_col_awq(): def test_get_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_gptq", "test_get_weights_col_gptq",
@ -585,14 +613,13 @@ def test_get_weights_col_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
@ -617,7 +644,7 @@ def test_get_weights_col_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_gtpq(): def test_get_weights_col_gtpq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_gptq", "test_get_weights_col_gptq",
@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "gptq"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
scaled_scale_max = 0.3906 * 256 scaled_scale_max = 0.3906 * 256
@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_marlin(): def test_get_weights_col_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_marlin", "test_get_weights_col_marlin",
@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_weights_col( w = weights.get_weights_col(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(
@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
# test_get_weights_col_packed # test_get_weights_col_packed
def test_get_weights_col_packed_awq(): def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_gptq", "test_get_weights_col_packed_gptq",
@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
block_sizes = 1 block_sizes = 1
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
block_sizes = 1 block_sizes = 1
w = weights.get_weights_col_packed( w = weights.get_weights_col_packed(
prefix=prefix, prefix=prefix,
quantize=quantize,
block_sizes=block_sizes, block_sizes=block_sizes,
) )
@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_packed_gptq(): def test_get_weights_col_packed_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_gptq", "test_get_weights_col_packed_gptq",
@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_packed_marlin(): def test_get_weights_col_packed_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_weights_col_packed_marlin", "test_get_weights_col_packed_marlin",
@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
# test_get_multi_weights_col # test_get_multi_weights_col
def test_get_multi_weights_col_awq(): def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_gptq", "test_get_multi_weights_col_gptq",
@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "awq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
try: try:
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
except ValueError as e: except ValueError as e:
assert e.args[0] == "get_multi_weights_col is not supported for exl2" assert e.args[0] == "get_multi_weights_col is not supported for exl2"
def test_get_multi_weights_col_gptq(): def test_get_multi_weights_col_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_gptq", "test_get_multi_weights_col_gptq",
@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefixes = ["weight"] prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=prefixes, prefixes=prefixes,
quantize=quantize,
dim=0, dim=0,
) )
@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_marlin(): def test_get_multi_weights_col_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_col_marlin", "test_get_multi_weights_col_marlin",
@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col( w = weights.get_multi_weights_col(
prefixes=[prefix], prefixes=[prefix],
quantize=quantize,
dim=0, dim=0,
) )
@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
assert torch.allclose(w.s, expected_weight.s), "s mismatch" assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_row # test_get_weights_row
def test_get_multi_weights_row_awq(): def test_get_weights_row_awq(gptq_weights_loader_awq):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_gptq", "test_get_weights_row_gptq",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
) )
prefix = "weight" prefix = "weight"
quantize = "awq"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_exl2(): def test_get_weights_row_exl2():
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_exl2", "test_get_weights_row_exl2",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
) )
prefix = "weight" prefix = "weight"
quantize = "exl2"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
print(w) print(w)
@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_multi_weights_row_gptq(): def test_get_weights_row_gptq(gptq_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_gptq", "test_get_weights_row_gptq",
], ],
device="cpu", device="cpu",
dtype=torch.float32, dtype=torch.float32,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "gptq"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = GPTQWeight( expected_weight = GPTQWeight(
@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_marlin(): def test_get_weights_row_marlin(marlin_weights_loader):
weights = MockWeights( weights = MockWeights(
[ [
"test_get_multi_weights_row_marlin", "test_get_weights_row_marlin",
], ],
device="cpu", device="cpu",
dtype=torch.float16, dtype=torch.float16,
process_group=dummy_process_group, process_group=dummy_process_group,
dummy_fs=dummy_file_system, dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
) )
prefix = "weight" prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_row( w = weights.get_weights_row(
prefix=prefix, prefix=prefix,
quantize=quantize,
) )
expected_weight = MarlinWeight( expected_weight = MarlinWeight(

View File

@ -353,6 +353,7 @@ def quantize(
upload_to_model_id=upload_to_model_id, upload_to_model_id=upload_to_model_id,
percdamp=percdamp, percdamp=percdamp,
act_order=act_order, act_order=act_order,
sym=True,
) )

View File

@ -1,6 +1,9 @@
import torch import torch
from typing import List, Union
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils.weights import WeightsLoader, Weights
@dataclass @dataclass
class Exl2Weight: class Exl2Weight:
@ -21,3 +24,60 @@ class Exl2Weight:
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
return self.q_weight.device return self.q_weight.device
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)

View File

@ -1,4 +1,23 @@
from enum import Enum, auto
import torch import torch
from text_generation_server.utils.import_utils import SYSTEM
def get_fp8_linear() -> torch.nn.Module:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, minor = torch.cuda.get_device_capability()
if major == 8 and minor < 9:
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):

View File

@ -1,20 +1,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger
import os import os
from typing import Optional from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch import torch
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
SYSTEM, SYSTEM,
) )
from text_generation_server.utils.log import log_once
@dataclass
class GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
sym: bool
@dataclass @dataclass
@ -69,3 +63,345 @@ elif CAN_EXLLAMA:
pass pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
else:
g_idx = None
if weights.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering the input activations that are split onto several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
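As a side note on the synthesized `g_idx` above: when an AWQ checkpoint is converted to the GPTQ packing format and the exllama kernels are not used, every input row r is simply assigned to group r // groupsize. A toy check of that mapping (values are illustrative; `torch` is already imported in this module):
# Toy illustration of the synthesized g_idx: with 4-bit packing, each int32
# row of qweight covers 32 // 4 = 8 input rows, and row r belongs to group
# r // groupsize.
bits, groupsize = 4, 128
qweight_rows = 512                                # e.g. qweight.shape[0]
in_features = qweight_rows * (32 // bits)         # 4096 input rows
g_idx = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)
assert int(g_idx[0]) == 0 and int(g_idx[groupsize]) == 1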
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
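For reference, the scalar tensors consumed by `_get_gptq_params` are the same ones the `quantize` entry point stores next to the packed weights (see the `gptq_bits`, `gptq_groupsize` and `gptq_sym` lines added later in this diff). A minimal sketch of inspecting such a checkpoint outside the server with the safetensors API; the file name is hypothetical and the metadata may live in any shard:
# Illustrative only: reads the scalar quantization metadata written by
# `text-generation-server quantize`.
from safetensors import safe_open

def read_gptq_metadata(path: str):
    with safe_open(path, framework="pt", device="cpu") as f:
        keys = set(f.keys())
        bits = f.get_tensor("gptq_bits").item() if "gptq_bits" in keys else None
        groupsize = (
            f.get_tensor("gptq_groupsize").item() if "gptq_groupsize" in keys else None
        )
        # Checkpoints that predate `gptq_sym` were quantized asymmetrically.
        sym = f.get_tensor("gptq_sym").item() if "gptq_sym" in keys else False
    return bits, groupsize, sym

print(read_gptq_metadata("model-00001-of-00002.safetensors"))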

View File

@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from text_generation_server.utils.weights import DefaultWeightsLoader
DEV = torch.device("cuda:0") DEV = torch.device("cuda:0")
@ -869,6 +871,7 @@ def quantize(
upload_to_model_id: Optional[str], upload_to_model_id: Optional[str],
percdamp: float, percdamp: float,
act_order: bool, act_order: bool,
sym: bool,
): ):
print("loading model") print("loading model")
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
@ -891,6 +894,7 @@ def quantize(
dtype=torch.float16, dtype=torch.float16,
process_group=process_group, process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]}, aliases={"embed_tokens.weight": ["lm_head.weight"]},
weights_loader=DefaultWeightsLoader(),
) )
hooks = [] hooks = []
for name, module in model.named_modules(): for name, module in model.named_modules():
@ -943,6 +947,7 @@ def quantize(
percdamp=percdamp, percdamp=percdamp,
act_order=act_order, act_order=act_order,
hooks=hooks, hooks=hooks,
sym=sym,
) )
print(time.time() - tick) print(time.time() - tick)
@ -954,6 +959,7 @@ def quantize(
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_bits"] = torch.LongTensor([bits])
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
state_dict["gptq_sym"] = torch.BoolTensor([sym])
max_shard_size = "10GB" max_shard_size = "10GB"
shards, index = shard_checkpoint( shards, index = shard_checkpoint(

View File

@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ" "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
) )
elif quantize == "fp8": elif quantize == "fp8":
from text_generation_server.layers.fp8 import Fp8Linear from text_generation_server.layers.fp8 import get_fp8_linear
linear = Fp8Linear(weight, bias) linear = get_fp8_linear()(weight, bias)
elif quantize == "bitsandbytes": elif quantize == "bitsandbytes":
try: try:
from text_generation_server.layers.bnb import ( from text_generation_server.layers.bnb import (

View File

@ -1,11 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from loguru import logger
from text_generation_server.layers.gptq import GPTQParams from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader
try: try:
import marlin_kernels import marlin_kernels
@ -24,16 +26,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16 MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. Share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. Share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
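For context, a hedged sketch of how this loader is wired in: it is built once from the checkpoint metadata and handed to Weights, whose get_weights_* helpers dispatch back into it (the surrounding variables are assumed to be in scope, as in the model constructors later in this diff):
# Illustrative wiring only; in the server this is done by the quantization
# helper that inspects the checkpoint before the model is built.
loader = MarlinWeightsLoader(bits=4, is_marlin_24=False)
weights = Weights(
    filenames,              # safetensors shards, assumed in scope
    device=device,
    dtype=dtype,
    process_group=process_group,
    weights_loader=loader,
)
# Column/row loads now return MarlinWeight (or GPTQMarlin24Weight) instances.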
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return ( return (
SYSTEM == "cuda" SYSTEM == "cuda"
and marlin_kernels is not None and marlin_kernels is not None
and has_sm_8_0 and has_sm_8_0
and quantize == "gptq" and quantize == "gptq"
and gptq_params.quant_method == "gptq" and quant_method == "gptq"
and gptq_params.bits in GPTQ_MARLIN_BITS and bits in GPTQ_MARLIN_BITS
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES and groupsize in GPTQ_MARLIN_GROUP_SIZES
and gptq_params.sym and sym
) )
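For reference, a minimal sketch of calling the reworked keyword-only helper; the concrete values describe a typical 4-bit, group-size-128, symmetric GPTQ checkpoint and are illustrative:
# Illustrative call; repacking only happens when every condition checked in
# the helper above (CUDA, marlin_kernels, sm_80+, GPTQ quantization,
# supported bits/groupsize, symmetric quantization) is met.
if can_use_gptq_marlin(
    bits=4,
    groupsize=128,
    quant_method="gptq",
    quantize="gptq",
    sym=True,
):
    log_once(logger.info, "Using GPTQ-Marlin kernels")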
@ -339,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module):
return C return C
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
qweight, scale = fp8_quantize(weight)
scale = scale.to(torch.float16)
qweight, scales = repack_fp8_for_marlin(qweight, scale)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
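A hedged usage sketch of the new layer: it needs a CUDA device with the marlin_kernels extension, as checked in the constructor; shapes and dtype below are illustrative:
# Illustrative only: quantizes an fp16 weight to FP8, repacks it for the
# Marlin kernel, and runs one forward pass.
import torch

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
layer = GPTQMarlinFP8Linear(weight, bias=None)
x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = layer(x)        # -> shape (2, 4096)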
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
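A toy shape check of the packing above (CPU is fine; torch >= 2.1 is assumed for the float8 dtype, and pack_fp8_as_int32 is the function defined here):
# Four consecutive FP8 rows are reinterpreted as bytes and OR-ed into one
# int32 row, with the lowest row landing in the lowest byte.
import torch

w = torch.randn(8, 2).to(torch.float8_e4m3fn)    # 8 rows -> 2 packed rows
packed = pack_fp8_as_int32(w)
assert packed.shape == (2, 2) and packed.dtype == torch.int32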
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layer weights have shape [out_features, in_features],
# while GPTQ-packed weights use [in_features/pack_factor, out_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = scale.reshape(1, 1).repeat(1, out_features)
scales = permute_scales(scales)
return repacked, scales
@dataclass @dataclass
class MarlinWeight: class MarlinWeight:
""" """

View File

@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module):
max_position_embeddings=rope_scaling[ max_position_embeddings=rope_scaling[
"original_max_position_embeddings" "original_max_position_embeddings"
], ],
base=10000.0, base=base,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module):
beta_fast=32, beta_fast=32,
beta_slow=1, beta_slow=1,
) )
elif rope_scaling["type"] == "su": elif rope_scaling["type"] in ["su", "longrope"]:
short_factor = torch.tensor( short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device rope_scaling["short_factor"], dtype=torch.float32, device=device
) )

View File

@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight") weight = weights.get_tensor(f"{prefix}.weight")
except: except:
# ...otherwise they are quantized. # ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1 should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1: elif weights.process_group.size() > 1:
try: try:
@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod @classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool): def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up( weight = weights.get_weights_col_packed_gate_up(prefix)
prefix, quantize=config.quantize
)
if bias: if bias:
raise NotImplementedError("packed_gate_up only implemented without bias") raise NotImplementedError("packed_gate_up only implemented without bias")
else: else:
@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
"""Specific method when the QKV was joined after the fact""" """Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv( weight = weights.get_weights_col_packed_qkv(
prefix, prefix,
quantize=config.quantize,
num_heads=num_heads, num_heads=num_heads,
num_key_value_heads=num_key_value_heads, num_key_value_heads=num_key_value_heads,
) )
@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
if bias: if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else: else:
@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
if config.quantize == "exl2": if config.quantize == "exl2":
linears = [] linears = []
for prefix in prefixes: for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize) weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize)) linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears) linear = LayerConcat(linears)
else: else:
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(prefixes, dim=dim)
prefixes, quantize=config.quantize, dim=dim
)
if bias: if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim) bias = torch.cat(b, dim=dim)
@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process

View File

@ -11,17 +11,27 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.opt import OPTSharded from text_generation_server.models.custom_modeling.neox_modeling import (
from text_generation_server.models.galactica import GalacticaSharded GPTNeoxForCausalLM,
from text_generation_server.models.santacoder import SantaCoder )
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.custom_modeling.phi_modeling import (
from text_generation_server.models.gpt_neox import GPTNeoxSharded PhiConfig,
from text_generation_server.models.phi import Phi PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -41,9 +51,6 @@ __all__ = [
"CausalLM", "CausalLM",
"GalacticaSharded", "GalacticaSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder",
"OPTSharded",
"T5Sharded",
"get_model", "get_model",
] ]
@ -53,38 +60,65 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
from text_generation_server.models.flash_neox import FlashNeoXSharded FlashLlamaForCausalLM,
from text_generation_server.models.flash_llama import (
FlashLlama,
) )
from text_generation_server.models.flash_qwen2 import ( from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashQwen2, FlashCohereForCausalLM,
) )
from text_generation_server.models.flash_cohere import ( from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashCohere, FlashGemmaForCausalLM,
) )
from text_generation_server.models.flash_gemma import ( from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma, FlashGemma2ForCausalLM,
) )
from text_generation_server.models.flash_gemma2 import ( from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashGemma2, FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
) )
from text_generation_server.models.pali_gemma import ( from text_generation_server.models.pali_gemma import (
PaliGemma, PaliGemmaBatch,
) )
from text_generation_server.models.flash_santacoder import ( from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
FlashSantacoderSharded, PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.custom_modeling.llava_next import (
from text_generation_server.models.idefics2 import Idefics2 LlavaNextForConditionalGeneration,
from text_generation_server.models.flash_mistral import FlashMistral )
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 FlashSantacoderForCausalLM,
from text_generation_server.models.flash_dbrx import FlashDbrx )
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e: except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}") logger.warning(f"Could not import Flash Attention enabled models: {e}")
@ -93,21 +127,7 @@ except ImportError as e:
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashCausalLM) __all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(IDEFICSSharded) __all__.append(IDEFICSSharded)
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashDbrx)
__all__.append(FlashPhi)
__all__.append(FlashQwen2)
__all__.append(FlashStarcoder2)
__all__.append(FlashGemma)
__all__.append(FlashGemma2)
__all__.append(FlashCohere)
MAMBA_AVAILABLE = True MAMBA_AVAILABLE = True
try: try:
@ -148,6 +168,11 @@ class ModelType(enum.Enum):
"name": "Gemma", "name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b", "url": "https://huggingface.co/google/gemma-7b",
} }
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = { GEMMA2 = {
"type": "gemma2", "type": "gemma2",
"name": "Gemma2", "name": "Gemma2",
@ -445,13 +470,16 @@ def get_model(
) )
if model_id.startswith("facebook/galactica"): if model_id.startswith("facebook/galactica"):
return GalacticaSharded( return CausalLM(
model_id, model_id=model_id,
revision, # Yes galactica is just an OPT model.
model_class=OPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=GalacticaCausalLMBatch,
) )
if ( if (
@ -460,22 +488,26 @@ def get_model(
and model_id.startswith("bigcode/") and model_id.startswith("bigcode/")
): ):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashSantacoderSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
) )
else: else:
return SantaCoder( return CausalLM.fallback(
model_id, model_id=model_id,
revision, revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -483,38 +515,44 @@ def get_model(
) )
if model_type == BLOOM: if model_type == BLOOM:
return BLOOMSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=BloomForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=BloomCausalLMBatch,
) )
elif model_type == MPT: elif model_type == MPT:
return MPTSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=MPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
) )
elif model_type == GPT2: elif model_type == GPT2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
try: try:
return FlashGPT2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
except RuntimeError as e: except RuntimeError as e:
# Lots of legacy models with various weight names. # Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}") logger.warning(f"Couldn't load flash gpt2 variant: {e}")
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -525,7 +563,7 @@ def get_model(
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -535,25 +573,28 @@ def get_model(
) )
elif model_type == GPT_NEOX: elif model_type == GPT_NEOX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashNeoXSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
return GPTNeoxSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=GPTNeoxForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -564,16 +605,18 @@ def get_model(
elif model_type == PHI: elif model_type == PHI:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashPhi( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -588,9 +631,11 @@ def get_model(
"Legacy phi-msft is not supported with Flash Attention" "Legacy phi-msft is not supported with Flash Attention"
) )
else: else:
return Phi( return CausalLM(
model_id, model_id=model_id,
revision, model_class=PhiForCausalLM,
config_class=PhiConfig,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -599,9 +644,10 @@ def get_model(
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashLlama( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -611,7 +657,7 @@ def get_model(
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -621,18 +667,22 @@ def get_model(
) )
if model_type == GEMMA: if model_type == GEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashGemma( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -642,18 +692,22 @@ def get_model(
) )
elif model_type == GEMMA2: elif model_type == GEMMA2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashGemma2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -664,18 +718,20 @@ def get_model(
if model_type == COHERE: if model_type == COHERE:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashCohere( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -686,18 +742,23 @@ def get_model(
if model_type == DBRX: if model_type == DBRX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashDbrx( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -711,27 +772,41 @@ def get_model(
if FLASH_ATTENTION: if FLASH_ATTENTION:
if config_dict.get("alibi", False): if config_dict.get("alibi", False):
raise NotImplementedError("sharded is not supported for this model") raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
) )
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
else: else:
if FLASH_ATTENTION and not config_dict.get("alibi", False): if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRWSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
) )
else: else:
return RW( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -742,18 +817,20 @@ def get_model(
if model_type == MISTRAL: if model_type == MISTRAL:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMistral( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -764,18 +841,20 @@ def get_model(
if model_type == MIXTRAL: if model_type == MIXTRAL:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMixtral( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -786,19 +865,22 @@ def get_model(
if model_type == STARCODER2: if model_type == STARCODER2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashStarcoder2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -809,17 +891,20 @@ def get_model(
if model_type == QWEN2: if model_type == QWEN2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashQwen2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -829,9 +914,10 @@ def get_model(
) )
if model_type == OPT: if model_type == OPT:
return OPTSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=OPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -839,13 +925,20 @@ def get_model(
) )
if model_type == T5: if model_type == T5:
return T5Sharded( return Seq2SeqLM(
model_id, model_id=model_id,
revision, model_class=T5ForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
) )
if model_type == IDEFICS: if model_type == IDEFICS:
if FLASH_ATTENTION: if FLASH_ATTENTION:
@ -861,34 +954,45 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS2: if model_type == IDEFICS2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return Idefics2( return VlmCausalLM(
model_id, model_id=model_id,
revision, model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma": if model_type == PALIGEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return PaliGemma( return VlmCausalLM(
model_id, model_id=model_id,
revision, model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == LLAVA_NEXT: if model_type == LLAVA_NEXT:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return LlavaNext( return VlmCausalLM(
model_id, model_class=LlavaNextForConditionalGeneration,
revision, model_id=model_id,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -912,7 +1016,7 @@ def get_model(
elif quantize == "exl2": elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel") raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -921,7 +1025,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM( return Seq2SeqLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -933,7 +1037,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None) auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None: if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys(): if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -942,7 +1046,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if "AutoModelForSeq2SeqLM" in auto_map.keys(): if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM( return Seq2SeqLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,

View File

@ -4,22 +4,12 @@ import torch.distributed
from typing import Optional, Type from typing import Optional, Type
from transformers import ( from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOMSharded(CausalLM): class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
slow_but_exact=False,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = BloomForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch return BloomCausalLMBatch

View File

@ -1,13 +1,26 @@
import torch import torch
import time import time
import torch.distributed
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
)
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
@ -478,10 +491,93 @@ class CausalLMBatch(Batch):
return len(self.requests) return len(self.requests)
@dataclass
class CausalLMBatchKeysLast(CausalLMBatch):
keys_head_dim_last: bool = False
class CausalLM(Model): class CausalLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
batch_class=CausalLMBatch,
):
self.batch_class = batch_class
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
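For reference, a sketch of the calling convention this constructor enables; it mirrors the BLOOM branch of get_model earlier in this diff, the model id is illustrative, and the call still relies on the server's distributed setup:
# Illustrative only: get_model now passes the modeling class and batch class
# straight to the generic CausalLM constructor instead of going through a
# dedicated BLOOMSharded subclass.
model = CausalLM(
    model_id="bigscience/bloom-560m",   # hypothetical example id
    model_class=BloomForCausalLM,
    batch_class=BloomCausalLMBatch,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
)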
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculator: Optional[str] = None, speculator: Optional[str] = None,
@ -537,7 +633,12 @@ class CausalLM(Model):
else: else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__( self = cls.__new__(
cls,
)
self.batch_class = CausalLMBatch
super().__init__(
self,
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -545,15 +646,11 @@ class CausalLM(Model):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
return self
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch return self.batch_class
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel):
class BloomForCausalLM(BloomPreTrainedModel): class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.transformer = BloomModel(config, weights) self.transformer = BloomModel(config, weights)

View File

@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module): class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig): def __init__(self, prefix: str, config: CLIPTextConfig):
super().__init__() super().__init__()
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig): def __init__(self, prefix, config: CLIPTextConfig):
super().__init__(config) super().__init__(config)
self.text_model = CLIPTextTransformer(config) self.text_model = CLIPTextTransformer(prefix, config)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()

View File

@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -363,9 +362,9 @@ class CohereMLP(nn.Module):
class FlashCohereLayer(nn.Module): class FlashCohereLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention( self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -416,18 +415,19 @@ class FlashCohereLayer(nn.Module):
class FlashCohereModel(torch.nn.Module): class FlashCohereModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashCohereLayer( FlashCohereLayer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -436,7 +436,7 @@ class FlashCohereModel(torch.nn.Module):
] ]
) )
self.norm = FastLayerNorm.load_no_bias( self.norm = FastLayerNorm.load_no_bias(
prefix="model.norm", weights=weights, eps=config.layer_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -486,10 +486,15 @@ class FlashCohereModel(torch.nn.Module):
class FlashCohereForCausalLM(torch.nn.Module): class FlashCohereForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = FlashCohereModel(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashCohereModel(prefix, config, weights)
try: try:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
@ -499,7 +504,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
except RuntimeError: except RuntimeError:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens", prefix=f"{prefix}.embed_tokens",
weights=weights, weights=weights,
) )
self.logit_scale = config.logit_scale self.logit_scale = config.logit_scale

View File

@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
class DbrxConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig):
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "n_heads",
"num_hidden_layers": "n_layers",
}
def __init__( def __init__(
self, self,
d_model: int = 2048, d_model: int = 2048,
@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
@property
def num_key_value_heads(self):
# We can't use the attribute map, since the number of KV
# heads is not top-level.
return self.attn_config.kv_n_heads
def promote_scalar(x: torch.Tensor) -> torch.Tensor: def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x return x.view(1) if len(x.size()) == 0 else x
@ -593,9 +605,9 @@ class DenseMoE(nn.Module):
class DbrxLayer(nn.Module): class DbrxLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"transformer.blocks.{layer_id}" prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm( self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
@ -637,16 +649,17 @@ class DbrxLayer(nn.Module):
class DbrxModel(torch.nn.Module): class DbrxModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.wte", weights=weights prefix=f"{prefix}.wte", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
DbrxLayer( DbrxLayer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -655,7 +668,7 @@ class DbrxModel(torch.nn.Module):
] ]
) )
self.norm = FastLayerNorm.load_no_bias( self.norm = FastLayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=1e-5 prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
) )
self.head_size = self.layers[0].attn.self_attn.head_size self.head_size = self.layers[0].attn.self_attn.head_size
@ -702,10 +715,15 @@ class DbrxModel(torch.nn.Module):
class FlashDbrxForCausalLM(torch.nn.Module): class FlashDbrxForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = DbrxModel(config, weights) if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = DbrxModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",

View File

@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
class Gemma2FastRMSNorm(FastRMSNorm): class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod @classmethod
def load(cls, prefix, weights, eps=1e-6): def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype dtype = weights.dtype
weights.dtype = torch.float32 weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1 weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights): def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
else: else:
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -305,7 +304,7 @@ class Gemma2MLP(nn.Module):
class FlashGemma2Layer(nn.Module): class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__() super().__init__()
self.self_attn = FlashGemma2Attention( self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -376,7 +375,7 @@ class FlashGemma2Layer(nn.Module):
class FlashGemma2Model(torch.nn.Module): class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -442,7 +441,7 @@ class FlashGemma2Model(torch.nn.Module):
class FlashGemma2ForCausalLM(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__() super().__init__()
embed_norm = config.hidden_size**0.5 embed_norm = config.hidden_size**0.5

View File

@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm): class GemmaFastRMSNorm(FastRMSNorm):
@classmethod @classmethod
def load(cls, prefix, weights, eps=1e-6): def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype dtype = weights.dtype
weights.dtype = torch.float32 weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1 weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights): def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
else: else:
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -261,7 +260,7 @@ class FlashGemmaAttention(torch.nn.Module):
class GemmaMLP(nn.Module): class GemmaMLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -299,7 +298,7 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module): class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
self.self_attn = FlashGemmaAttention( self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@ -354,7 +353,7 @@ class FlashGemmaLayer(nn.Module):
class FlashGemmaModel(torch.nn.Module): class FlashGemmaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -419,7 +418,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__() super().__init__()
embed_norm = config.hidden_size**0.5 embed_norm = config.hidden_size**0.5

View File

@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
# Weights # Weights
weight = weights.get_weights_col_packed_qkv( weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn", f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads, config.num_attention_heads,
config.num_attention_heads, config.num_attention_heads,
) )
@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices.""" """load_row, but with transposed weight matrices."""
if config.quantize == "gptq": if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
else: else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_col(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices.""" """load_col, but with transposed weight matrices."""
if config.quantize == "gptq": if config.quantize == "gptq":
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col([prefix], dim=1)
[prefix], quantize=config.quantize, dim=1
)
else: else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
@ -261,7 +258,7 @@ class FlashGPT2Attention(torch.nn.Module):
class GPT2MLP(nn.Module): class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
act = config.activation_function act = config.activation_function
self.act = ( self.act = (
@ -298,7 +295,7 @@ class GPT2MLP(nn.Module):
class FlashGPT2Layer(nn.Module): class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.self_attn = FlashGPT2Attention( self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights prefix=f"{prefix}.attn", config=config, weights=weights
@ -350,7 +347,7 @@ class FlashGPT2Layer(nn.Module):
class FlashGPT2Model(torch.nn.Module): class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -414,7 +411,7 @@ class FlashGPT2Model(torch.nn.Module):
class FlashGPT2ForCausalLM(torch.nn.Module): class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
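The dropped `quantize=config.quantize` arguments and the `get_multi_weights_row` -> `get_weights_row` rename in this file reflect the change that runs through the whole diff: quantization handling moves out of each call site and into a loader object owned by the weights container. A toy sketch of that dispatch, with illustrative names only (not the actual TGI `Weights` API):

class DefaultLoader:
    def get_row(self, name):
        return f"dense:{name}"

class QuantLoader:
    def get_row(self, name):
        return f"quantized:{name}"

class ToyWeights:
    def __init__(self, weights_loader):
        # the loader is chosen once, from the configured quantization method
        self.weights_loader = weights_loader

    def get_weights_row(self, prefix):
        # call sites no longer pass `quantize=`; they just ask for the row
        return self.weights_loader.get_row(f"{prefix}.weight")

assert ToyWeights(QuantLoader()).get_weights_row("h.0.attn") == "quantized:h.0.attn.weight"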

View File

@ -54,7 +54,7 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights, layer_id): def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite. # Only defined in granite.
bias = getattr(config, "attention_bias", False) bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads head_size = config.hidden_size // config.num_attention_heads
@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(

View File

@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
class MistralMLP(nn.Module): class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id): def __init__(self, prefix: str, config, weights, layer_id):
super().__init__() super().__init__()
self.hidden_act = config.hidden_act self.hidden_act = config.hidden_act
self.act = ( self.act = (
@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module): class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_id): def __init__(self, prefix: str, config, weights, layer_id):
super().__init__() super().__init__()
self.self_attn = MistralAttention( self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
class MistralModel(torch.nn.Module): class MistralModel(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, name=None): def __init__(self, prefix: str, config, weights, name=None):
if name is None: if name is None:
name = "model" name = "model"
super().__init__() super().__init__()

View File

@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights): def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
else: else:
@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -155,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
) )
def _load_experts(config, prefix, mat, weights): def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None: if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.") raise NotImplementedError("Mixtral does not support weight quantization yet.")
@ -475,7 +474,7 @@ class DenseMoE(nn.Module):
class MixtralLayer(nn.Module): class MixtralLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"{prefix}.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
@ -536,7 +535,7 @@ class MixtralLayer(nn.Module):
class MixtralModel(torch.nn.Module): class MixtralModel(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
@ -610,7 +609,7 @@ class MixtralModel(torch.nn.Module):
class FlashMixtralForCausalLM(torch.nn.Module): class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = MixtralModel(prefix, config, weights) self.model = MixtralModel(prefix, config, weights)

View File

@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) weight = weights.get_multi_weights_col([prefix], dim=0)
if isinstance(weight, torch.Tensor): if isinstance(weight, torch.Tensor):
# Only on non quantized versions # Only on non quantized versions
weight = ( weight = (
@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.embed_in = TensorParallelEmbedding( self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights prefix=f"{prefix}.embed_in", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
] ]
) )
self.final_layer_norm = FastLayerNorm.load( self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm", prefix=f"{prefix}.final_layer_norm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights

View File

@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -258,9 +257,9 @@ class PhiMLP(nn.Module):
class FlashPhiLayer(nn.Module): class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention( self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -307,18 +306,19 @@ class FlashPhiLayer(nn.Module):
class FlashPhiModel(torch.nn.Module): class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashPhiLayer( FlashPhiLayer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -378,10 +378,15 @@ class FlashPhiModel(torch.nn.Module):
class FlashPhiForCausalLM(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = FlashPhiModel(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashPhiModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",

View File

@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
class Qwen2Layer(nn.Module): class Qwen2Layer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention( self.self_attn = Qwen2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
class Qwen2Model(torch.nn.Module): class Qwen2Model(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Qwen2Layer( Qwen2Layer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
] ]
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
class Qwen2ForCausalLM(torch.nn.Module): class Qwen2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = Qwen2Model(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Qwen2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",

View File

@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
attribute_map = { attribute_map = {
"num_hidden_layers": "n_layer", "num_hidden_layers": "n_layer",
"num_attention_heads": "n_head", "num_attention_heads": "n_head",
"num_key_value_heads": "n_head_kv",
} }
def __init__( def __init__(
@ -127,7 +128,7 @@ class FlashRWAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
config, config,
prefix, prefix: str,
weights, weights,
): ):
super().__init__() super().__init__()
@ -236,7 +237,7 @@ class FlashRWLargeAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
config, config,
prefix, prefix: str,
weights, weights,
): ):
super().__init__() super().__init__()
@ -358,7 +359,7 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module): class FlashMLP(nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix: str, weights):
super().__init__() super().__init__()
self.act = torch.nn.functional.gelu self.act = torch.nn.functional.gelu
@ -380,6 +381,7 @@ class FlashRWLayer(nn.Module):
def __init__( def __init__(
self, self,
layer_id, layer_id,
prefix: str,
config, config,
weights, weights,
): ):
@ -388,7 +390,7 @@ class FlashRWLayer(nn.Module):
parallel_attn = config.parallel_attn parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn self.parallel_attn = parallel_attn
prefix = f"transformer.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load( self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", prefix=f"{prefix}.input_layernorm",
@ -479,7 +481,7 @@ class FlashRWLayer(nn.Module):
class FlashRWLayerNorm(nn.Module): class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix: str, weights):
super().__init__() super().__init__()
self.num_ln = config.num_ln_in_parallel_attn self.num_ln = config.num_ln_in_parallel_attn
@ -518,9 +520,9 @@ class FlashRWLayerNorm(nn.Module):
class FlashRWLargeLayer(nn.Module): class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, prefix: str, config, weights):
super().__init__() super().__init__()
prefix = f"transformer.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
self.ln_layer = FlashRWLayerNorm(config, prefix, weights) self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
@ -580,18 +582,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.word_embeddings = TensorParallelEmbedding( self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights prefix=f"{prefix}.word_embeddings", weights=weights
) )
if config.new_decoder_architecture: if config.new_decoder_architecture:
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLargeLayer(layer_id, config, weights) FlashRWLargeLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
@ -599,14 +601,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
else: else:
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLayer(layer_id, config, weights) FlashRWLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.cache_size = self.h[0].self_attention.num_heads_kv self.cache_size = self.h[0].self_attention.num_heads_kv
self.ln_f = FastLayerNorm.load( self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f", prefix=f"{prefix}.ln_f",
weights=weights, weights=weights,
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
@ -653,10 +655,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.transformer = FlashRWModel(config, weights) if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.transformer = FlashRWModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

View File

@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
get_linear, get_linear,
) )
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device) qzeros = qzeros.to(device=weights.device)
gptq_params = weights._get_gptq_params() loader = weights.weights_loader
if gptq_params.quant_method == "gptq": assert isinstance(loader, GPTQWeightsLoader)
loader._get_gptq_params(weights)
if loader.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device) g_idx = g_idx.to(device=weights.device)
elif gptq_params.quant_method == "awq": elif loader.quant_method == "awq":
g_idx = None g_idx = None
from text_generation_server.layers.awq.conversion_utils import ( from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq, fast_awq_to_gptq,
@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
qzeros=qzeros, qzeros=qzeros,
scales=scales, scales=scales,
g_idx=g_idx, g_idx=g_idx,
bits=gptq_params.bits, bits=loader.bits,
groupsize=gptq_params.groupsize, groupsize=loader.groupsize,
use_exllama=HAS_EXLLAMA, use_exllama=HAS_EXLLAMA,
) )
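Here the per-call `weights._get_gptq_params()` lookup is replaced by reading the quantization metadata straight off the loader attached to the weights. A hedged sketch of that shape (a toy dataclass, not the real GPTQWeightsLoader):

from dataclasses import dataclass

@dataclass
class ToyGPTQLoader:
    bits: int = 4
    groupsize: int = 128
    quant_method: str = "gptq"

loader = ToyGPTQLoader()
# call sites read bits/groupsize/quant_method directly from the loader
g_idx_needed = loader.quant_method == "gptq"
assert g_idx_needed and loader.bits == 4 and loader.groupsize == 128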
@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
if config.transpose: if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else: else:
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col([prefix], dim=0)
[prefix], quantize=config.quantize, dim=0
)
if bias: if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
if config.transpose: if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else: else:
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0: if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process # Rank is only on the first rank process
@ -346,16 +347,16 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"transformer.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
self.ln_1 = FastLayerNorm.load( self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
) )
self.ln_2 = FastLayerNorm.load( self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
) )
self.attn = FlashMQAttention( self.self_attn = FlashMQAttention(
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
config=config, config=config,
weights=weights, weights=weights,
@ -378,7 +379,7 @@ class Block(nn.Module):
max_s, max_s,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn( hidden_states = self.self_attn(
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
@ -396,25 +397,26 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module): class FlashSantacoderModel(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.config = config self.config = config
self.process_group = weights.process_group self.process_group = weights.process_group
self.wte = TensorParallelEmbedding( self.wte = TensorParallelEmbedding(
prefix="transformer.wte", prefix=f"{prefix}.wte",
weights=weights, weights=weights,
reduce=False, reduce=False,
) )
self.wpe = TensorParallelEmbedding( self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe", prefix=f"{prefix}.wpe",
weights=weights, weights=weights,
reduce=False, reduce=False,
) )
self.h = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Block( Block(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -426,8 +428,8 @@ class FlashSantacoderModel(nn.Module):
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
) )
self.head_size = self.h[0].attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.h[0].attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
def forward( def forward(
self, self,
@ -446,7 +448,7 @@ class FlashSantacoderModel(nn.Module):
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None residual = None
for i, layer in enumerate(self.h): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
residual, residual,
@ -464,11 +466,18 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module): class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.transformer = FlashSantacoderModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
config.transpose = config.architectures[0].startswith("GPT2")
self.model = FlashSantacoderModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix=f"{prefix}.wte", weights=weights
) )
def forward( def forward(
@ -485,7 +494,7 @@ class FlashSantacoderForCausalLM(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,

View File

@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0, dim=0,
) )
@ -417,14 +416,14 @@ class Starcoder2Layer(nn.Module):
class Starcoder2Model(torch.nn.Module): class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
@ -437,7 +436,7 @@ class Starcoder2Model(torch.nn.Module):
] ]
) )
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -489,10 +488,15 @@ class Starcoder2Model(torch.nn.Module):
class FlashStarcoder2ForCausalLM(torch.nn.Module): class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.model = Starcoder2Model(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Starcoder2Model(prefix, config, weights)
try: try:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
@ -502,7 +506,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
except RuntimeError: except RuntimeError:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens", prefix=f"{prefix}.embed_tokens",
weights=weights, weights=weights,
) )

View File

@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.config = config self.config = config
config.text_config.quantize = config.quantize config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator config.text_config.speculator = config.speculator
self.language_model = load_text_model( self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model", prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config, config=config.text_config,
weights=weights, weights=weights,
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None, image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.language_model.embed_tokens(input_ids) inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0: if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
input_ids, inputs_embeds, image_features input_ids, inputs_embeds, image_features
) )
hidden_states = self.language_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states) logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits return logits, speculative_logits
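The rename from `language_model` to `text_model` matters beyond this file: the adapter-loading code later in this diff discovers the inner causal LM by attribute name, so the two sides must stay in sync. A toy illustration of that lookup (illustrative helper, not the actual code):

class ToyVLM:
    def __init__(self, inner):
        self.text_model = inner

def unwrap_text_model(model):
    # mirrors the hasattr() chain used by the adapter code in this diff
    if hasattr(model, "language_model"):
        return model.language_model
    if hasattr(model, "text_model"):
        return model.text_model
    return model

assert unwrap_text_model(ToyVLM("inner")) == "inner"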

View File

@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
class MPTModel(MPTPreTrainedModel): class MPTModel(MPTPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
# config._validate_config() # config._validate_config()
super().__init__(config) super().__init__(config)
self.world_size = weights.process_group.size() self.world_size = weights.process_group.size()
@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
f"Requested norm type ({config.norm_type}) is not implemented within this repo." f"Requested norm type ({config.norm_type}) is not implemented within this repo."
) )
self.wte = TensorParallelEmbedding("transformer.wte", weights) self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
if not self.alibi: if not self.alibi:
self.wpe = TensorParallelEmbedding("transformer.wpe", weights) self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
for i in range(config.n_layers) for i in range(config.n_layers)
] ]
) )
@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
class MPTForCausalLM(MPTPreTrainedModel): class MPTForCausalLM(MPTPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
if not config.tie_word_embeddings: if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings") raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights) self.transformer = MPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix=f"{prefix}.wte", weights=weights
) )
self.logit_scale = None self.logit_scale = None
if config.logit_scale is not None: if config.logit_scale is not None:

View File

@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
class GPTNeoXLayer(nn.Module): class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, prefix: str, config, weights):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load( self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
self.post_attention_layernorm = nn.LayerNorm.load( self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
self.attention = GPTNeoXAttention( self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
) )
self.mlp = GPTNeoXMLP( self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
) )
def forward( def forward(
@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
class GPTNeoXModel(GPTNeoXPreTrainedModel): class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding( self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights prefix=f"{prefix}.embed_in", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
GPTNeoXLayer(layer_id, config, weights) GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.final_layer_norm = nn.LayerNorm.load( self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm", prefix=f"{prefix}.final_layer_norm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights
) )

View File

@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
This module learns positional embeddings up to a fixed maximum size. This module learns positional embeddings up to a fixed maximum size.
""" """
def __init__(self, weights): def __init__(self, prefix: str, weights):
super().__init__() super().__init__()
self.offset = 2 self.offset = 2
self.weight = nn.Parameter( self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight") weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
) )
def forward( def forward(
@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
class OPTDecoderLayer(nn.Module): class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights): def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}" prefix = f"{prefix}.decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
class OPTDecoder(OPTPreTrainedModel): class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights): def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config) super().__init__(config)
self.dropout = config.dropout self.dropout = config.dropout
self.layerdrop = config.layerdrop self.layerdrop = config.layerdrop
@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights prefix=f"{prefix}.decoder.embed_tokens", weights=weights
) )
self.embed_positions = OPTLearnedPositionalEmbedding(weights) self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load( self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", weights=weights, bias=False config,
prefix=f"{prefix}.decoder.project_out",
weights=weights,
bias=False,
) )
else: else:
self.project_out = None self.project_out = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load( self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", weights=weights, bias=False config,
prefix=f"{prefix}.decoder.project_in",
weights=weights,
bias=False,
) )
else: else:
self.project_in = None self.project_in = None
@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164 # see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm: if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load( self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
) )
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
OPTDecoderLayer(layer_id, config, weights) OPTDecoderLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
class OPTModel(OPTPreTrainedModel): class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights): def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config) super().__init__(config)
self.decoder = OPTDecoder(config, weights) self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing # Initialize weights and apply final processing
def forward( def forward(
@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
class OPTForCausalLM(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__(config) super().__init__(config)
self.model = OPTModel(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
) )
def forward( def forward(

View File

@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
# PhiModel implements the embedding layer and the transformer blocks. # PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module): class PhiModel(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.tp_rank = weights.process_group.rank() self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size() self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights prefix=f"{prefix}.embd.wte", weights=weights
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
PhiBlock(f"transformer.h.{layer_id}", config, weights) PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer) for layer_id in range(config.n_layer)
] ]
) )
@ -289,9 +289,15 @@ class PhiModel(nn.Module):
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module): class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = PhiModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights) self.lm_head = PhiCausalLMHead(config, weights)
def forward( def forward(

View File

@ -10,7 +10,12 @@ import numpy as np
from loguru import logger from loguru import logger
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@ -21,6 +26,12 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
Tokens, Tokens,
@ -39,6 +50,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
@ -799,29 +811,123 @@ class FlashCausalLMBatch(Batch):
return len(self.requests) return len(self.requests)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashCausalLM(Model): class FlashCausalLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model: torch.nn.Module, model_class,
tokenizer: PreTrainedTokenizerBase, revision: Optional[str] = None,
num_layers: int, quantize: Optional[str] = None,
num_kv_heads: int, speculator: Optional[str] = None,
head_size: int, dtype: Optional[torch.dtype] = None,
dtype: torch.dtype, trust_remote_code: bool = False,
device: torch.device, lora_adapter_ids: Optional[list] = [],
rank: int = 0, tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
world_size: int = 1, config_class: PreTrainedTokenizerBase = AutoConfig,
sliding_window: Optional[int] = None, default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads=None,
skip_special_tokens: bool = True,
): ):
self.num_layers = num_layers self.process_group, rank, world_size = initialize_torch_distributed()
self.num_kv_heads = num_kv_heads if torch.cuda.is_available():
self.head_size = head_size device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError(f"{model_class} is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(quantize, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
# GPT-2 workaround
if num_kv_heads is None:
num_kv_heads = getattr(config, "n_head", None)
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
else num_kv_heads
)
assert self.num_kv_heads > 0
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {} self.cuda_graphs = {}
self.kv_cache = [] self.kv_cache = []
super(FlashCausalLM, self).__init__( super().__init__(
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -830,7 +936,7 @@ class FlashCausalLM(Model):
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
sliding_window=sliding_window, sliding_window=config.sliding_window,
) )
@property @property
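The `num_kv_heads` handling added above shards the KV heads across the tensor-parallel group unless the model has a single (multi-query) head, which is replicated. A quick worked check of that arithmetic:

def shard_kv_heads(num_kv_heads: int, world_size: int) -> int:
    # GQA/MHA heads are split across ranks; a single MQA head is replicated
    return num_kv_heads // world_size if num_kv_heads > 1 else num_kv_heads

assert shard_kv_heads(8, 4) == 2   # 8 KV heads over 4 GPUs -> 2 per rank
assert shard_kv_heads(1, 4) == 1   # multi-query attention: keep the one head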
@ -1578,3 +1684,72 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns) return generations, batch, (forward_ns, decode_ns)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2, which doesn't have these layers
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
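The adapter plumbing above keys LoRA targets by (layer index, projection name) and flags which projections are row-parallel. A small sketch of how such a mapping might be consumed downstream; the `split_dim` helper is hypothetical and only illustrates the usual tensor-parallel convention:

ADAPTER_LAYERS = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}

def split_dim(layer_type: str) -> int:
    # column-parallel weights are typically sharded on the output dim (0),
    # row-parallel weights on the input dim (1)
    return 1 if layer_type in ROW_PARALLEL else 0

assert [l for l in ADAPTER_LAYERS if split_dim(l) == 1] == ["o_proj", "down_proj"]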

View File

@ -1,75 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashCohere(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashCohere is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashCohereForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,100 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashDbrx(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashDBRX is only available on GPU")
try:
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
try:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
# FIXME: change back to model id once the tokenizer.json is merged
tokenizer = GPT2TokenizerFast.from_pretrained(
"Xenova/dbrx-instruct-tokenizer",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = DbrxConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashDbrxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PretrainedConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGPT2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,171 +0,0 @@
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@ -1,24 +1,7 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [
@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
-class BaseFlashMistral(FlashCausalLM):
+class FlashMistral(FlashCausalLM):
def __init__(
self,
model_cls,
model_id: str,
config_cls=AutoConfig,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashMistral is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_cls(prefix, config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.model.layers),
model.model.num_key_value_heads,
model.model.head_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
-if hasattr(self.model, "language_model"):
+if hasattr(self.model, "text_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMistral, self).__init__(
config_cls=MistralConfig,
model_cls=FlashMistralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -1,31 +0,0 @@
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashNeoXSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,111 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if speculator:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(speculator).exists() and Path(speculator).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
speculator, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
speculator, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(speculator) / "config.json")
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,93 +0,0 @@
import math
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashQwen2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = Qwen2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)

View File

@ -1,91 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
)
config.quantize = quantize
config.speculator = speculator
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashRWSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,99 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
import json
import os
from huggingface_hub import hf_hub_download
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashSantacoderSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=True,
)
config.quantize = quantize
config.speculator = speculator
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

View File

@ -1,84 +0,0 @@
import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashStarcoder2 is only available on GPU")
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = Starcoder2Config.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashStarcoder2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)

View File

@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
class GalacticaSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -1,89 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class GPTNeoxSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = GPTNeoxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -23,6 +23,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
class IDEFICSSharded(IdeficsCausalLM):
@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)

View File

@ -1,51 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@ -1,46 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class LlavaNext(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
super().__init__(
model_cls=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)

View File

@ -28,6 +28,7 @@ from text_generation_server.models.types import (
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -448,8 +449,17 @@ class Mamba(Model):
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-weights = Weights(filenames, device, dtype, process_group=self.process_group)
+weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(
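The Mamba hunk above wires a quantization-aware loader into Weights the same way the IDEFICS hunk does. A minimal sketch of that call pattern, assuming only the signatures visible in this diff; the model id below is a hypothetical placeholder, not a real repository:
import torch

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.quantization import get_loader

model_id, revision, quantize = "my-org/my-model", None, None  # placeholders

process_group, rank, world_size = initialize_torch_distributed()
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Resolve the (possibly quantized) weights loader first, then hand it to Weights.
weights_loader = get_loader(quantize=quantize, model_id=model_id, revision=revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
    filenames,
    device,
    dtype,
    process_group=process_group,
    weights_loader=weights_loader,
)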

View File

@ -60,7 +60,7 @@ class Model(ABC):
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
LayerAdapterWeights
)
-self.target_to_layer = self.adapter_target_to_layer()
+self.target_to_layer = None
self.loaded_adapters = set()
self.static_adapter_id = adapter_id
@ -187,6 +187,8 @@ class Model(ABC):
into model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters.
"""
if self.target_to_layer is None:
self.target_to_layer = self.adapter_target_to_layer()
if adapter_index in self.loaded_adapters:
# Adapter already loaded
return
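The model.py change above stops building target_to_layer eagerly in the constructor and instead builds it on the first call to load_adapter. A generic illustration of that lazy-initialization pattern with a toy class; the names and stand-in mapping are invented for the example and do not mirror the server's real implementation:
from typing import Dict, Optional, Tuple


class ToyAdapterHost:
    def __init__(self) -> None:
        # Previously computed eagerly in __init__; now deferred until first use.
        self.target_to_layer: Optional[Dict[Tuple[int, str], str]] = None
        self.loaded_adapters: set = set()

    def adapter_target_to_layer(self) -> Dict[Tuple[int, str], str]:
        # Stand-in for the real per-layer weight discovery, which can be costly
        # and is wasted work if no adapter is ever loaded.
        return {(i, "q_proj"): f"model.layers.{i}.self_attn.q_proj" for i in range(2)}

    def load_adapter(self, adapter_index: int) -> None:
        # Build the mapping lazily, exactly once, on the first adapter load.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            return  # adapter already loaded
        self.loaded_adapters.add(adapter_index)


host = ToyAdapterHost()
host.load_adapter(0)
print(sorted(host.target_to_layer))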

View File

@ -1,105 +0,0 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class MPTCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class MPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
config.quantize = quantize
model = MPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MPTCausalLMBatch

View File

@ -1,86 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class OPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
class PaliGemma(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
super().__init__(
config_cls=AutoConfig,
model_cls=PaliGemmaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self):
return PaliGemmaBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

Some files were not shown because too many files have changed in this diff.