Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

Commit 81c9ad7073: Merge branch 'main' into feature/usage-stats
.github/workflows/autodocs.yaml (5 changes)
@@ -30,6 +30,10 @@ jobs:
         id: install-router
         run: cargo install --path router/

+      - uses: actions/setup-node@v4
+        with:
+          node-version: 22
+
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
@@ -37,4 +41,5 @@ jobs:

       - name: Check that documentation is up-to-date
         run: |
+          npm install -g swagger-cli
           python update_doc.py --check
.github/workflows/build.yaml (13 changes)
@@ -11,6 +11,11 @@ on:
         # - rocm
         # - intel
         required: true
+      release-tests:
+        description: "Run release integration tests"
+        required: true
+        default: false
+        type: boolean

 jobs:
   build-and-push:
@@ -23,7 +28,7 @@ jobs:
       group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     # TODO see with @Glegendre to get CPU runner here instead
-    runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
+    runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
     permissions:
       contents: write
       packages: write
@@ -131,8 +136,8 @@ jobs:
            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
-         cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
-         cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
+         cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
+         cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
       - name: Final
         id: final
         run: |
@@ -148,7 +153,7 @@ jobs:
     runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
     if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
     env:
-      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }}
+      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
.github/workflows/ci_build.yaml (11 changes)
@@ -20,7 +20,14 @@ on:
       - "Dockerfile_amd"
       - "Dockerfile_intel"
     branches:
-      - 'main'
+      - "main"
+  workflow_dispatch:
+    inputs:
+      release-tests:
+        description: "Run release integration tests"
+        required: true
+        default: false
+        type: boolean

 jobs:
   build:
@@ -33,4 +40,6 @@ jobs:
     uses: ./.github/workflows/build.yaml # calls the one above ^
     with:
       hardware: ${{ matrix.hardware }}
+      # https://github.com/actions/runner/issues/2206
+      release-tests: ${{ inputs.release-tests == true }}
     secrets: inherit
Cargo.lock (28 changes)
@@ -1935,17 +1935,6 @@ version = "2.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"

-[[package]]
-name = "metrics"
-version = "0.21.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
-dependencies = [
- "ahash",
- "metrics-macros",
- "portable-atomic",
-]
-
 [[package]]
 name = "metrics"
 version = "0.23.0"
@@ -1969,7 +1958,7 @@ dependencies = [
  "hyper-util",
  "indexmap 2.2.6",
  "ipnet",
- "metrics 0.23.0",
+ "metrics",
  "metrics-util",
  "quanta",
  "thiserror",
@@ -1977,17 +1966,6 @@ dependencies = [
  "tracing",
 ]

-[[package]]
-name = "metrics-macros"
-version = "0.7.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn 2.0.68",
-]
-
 [[package]]
 name = "metrics-util"
 version = "0.17.0"
@@ -1997,7 +1975,7 @@ dependencies = [
  "crossbeam-epoch",
  "crossbeam-utils",
  "hashbrown 0.14.5",
- "metrics 0.23.0",
+ "metrics",
  "num_cpus",
  "quanta",
  "sketches-ddsketch",
@@ -3835,7 +3813,7 @@ dependencies = [
  "init-tracing-opentelemetry",
  "itertools 0.10.5",
  "jsonschema",
- "metrics 0.21.1",
+ "metrics",
  "metrics-exporter-prometheus",
  "minijinja",
  "minijinja-contrib",
Dockerfile
@@ -40,7 +40,9 @@ RUN cargo build --profile release-opt
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
 FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install

+# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
 ARG PYTORCH_VERSION=2.3.0
+
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
@@ -241,7 +243,10 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
+    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
+    pip install nvidia-nccl-cu12==2.22.3
+
+ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2

 # Deps before the binaries
 # The binaries change on every build given we burn the SHA into them
README.md (27 changes)
@@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint.

 ## Table of contents

 - [Get Started](#get-started)
-- [API Documentation](#api-documentation)
-- [Using a private or gated model](#using-a-private-or-gated-model)
-- [A note on Shared Memory](#a-note-on-shared-memory-shm)
-- [Distributed Tracing](#distributed-tracing)
-- [Local Install](#local-install)
-- [CUDA Kernels](#cuda-kernels)
-- [Optimized architectures](#optimized-architectures)
-- [Run Mistral](#run-a-model)
-- [Run](#run)
-- [Quantization](#quantization)
-- [Develop](#develop)
-- [Testing](#testing)
+- [Docker](#docker)
+- [API documentation](#api-documentation)
+- [Using a private or gated model](#using-a-private-or-gated-model)
+- [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
+- [Distributed Tracing](#distributed-tracing)
+- [Architecture](#architecture)
+- [Local install](#local-install)
+- [Optimized architectures](#optimized-architectures)
+- [Run locally](#run-locally)
+- [Run](#run)
+- [Quantization](#quantization)
+- [Develop](#develop)
+- [Testing](#testing)

 Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:
@@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel):
 class ChoiceDelta(BaseModel):
     role: str
     content: Optional[str] = None
-    tool_calls: Optional[ChoiceDeltaToolCall]
+    tool_calls: Optional[ChoiceDeltaToolCall] = None


 class Choice(BaseModel):
@@ -11,6 +11,8 @@
     title: Using TGI with Intel Gaudi
   - local: installation_inferentia
     title: Using TGI with AWS Inferentia
+  - local: installation_intel
+    title: Using TGI with Intel GPUs
   - local: installation
     title: Installation from source
   - local: supported_models
@@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin

 - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
 - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
+- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
 - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
 - A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
 - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).
docs/source/installation_intel.md (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Using TGI with Intel GPUs
+
+TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html), [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html), the recommended usage is through Docker.
+
+
+On a server powered by Intel GPUs, TGI can be launched with the following command:
+
+```bash
+model=teknium/OpenHermes-2.5-Mistral-7B
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --rm --privileged --cap-add=sys_nice \
+  --device=/dev/dri \
+  --ipc=host --shm-size 1g --net host -v $volume:/data \
+  ghcr.io/huggingface/text-generation-inference:latest-intel \
+  --model-id $model --cuda-graphs 0
+```
+
+The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
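Once the container above is running, the server exposes the standard TGI HTTP API on the host (port 80 by default with `--net host`). Below is a minimal client sketch, not part of the commit, written in Rust and assuming the `reqwest` crate (with the `blocking` and `json` features) and `serde_json` are available; the `/generate` route and the `inputs`/`parameters` payload follow the public TGI API.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a blocking HTTP client; the URL assumes the docker command above
    // is running on the same machine with --net host (TGI listens on port 80).
    let client = reqwest::blocking::Client::new();
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": { "max_new_tokens": 20 }
    });
    // POST /generate returns a JSON object with a "generated_text" field.
    let resp: serde_json::Value = client
        .post("http://localhost:80/generate")
        .json(&body)
        .send()?
        .error_for_status()?
        .json()?;
    println!("{}", resp["generated_text"]);
    Ok(())
}
```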
@@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \

 ### Supported hardware

-TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
+TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.

 ## Consuming TGI

@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
+- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
 - [Gemma2](https://huggingface.co/google/gemma2-9b)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
Generation-details snapshot (JSON); token entries condensed to (id, logprob, text), all with "special": false:
@@ -5,85 +5,80 @@
   "generated_tokens": 10,
-  "prefill": [(1, null, "<s>"), (4321, -9.7890625, "Test"), (2009, -9.625, "request")]
+  "prefill": [(2323, null, "Test"), (1715, -11.34375, " request")]
   "seed": null,
-  "tokens": [(13, -2.3359375, "\n"), (3057, -1.8779297, "Test"), (2009, -1.2744141, " request"), (13, -1.6933594, "\n"), (3057, -1.4648438, "Test"), (2009, -0.15600586, " request"), (13, -0.8027344, "\n"), (3057, -0.23022461, "Test"), (2009, -0.0069885254, " request"), (13, -0.02218628, "\n")]
+  "tokens": [(198, -2.5742188, "\n"), (262, -1.6230469, " "), (3270, -2.046875, " \"\"\"\n"), (262, -0.015281677, " "), (422, -2.1425781, " if"), (1715, -0.9238281, " request"), (13204, -0.076660156, ".method"), (624, -0.021987915, " =="), (364, -0.39208984, " '"), (3019, -0.10821533, "POST")]
   "top_tokens": null
-  "generated_text": "\nTest request\nTest request\nTest request\n"
+  "generated_text": "\n \"\"\"\n if request.method == 'POST"
Generation-details snapshot (JSON, sampled with seed 0); token entries condensed to (id, logprob, text), all with "special": false:
@@ -5,85 +5,80 @@
   "generated_tokens": 10,
-  "prefill": [(1, null, "<s>"), (4321, -9.84375, "Test"), (2009, -9.6015625, "request")]
+  "prefill": [(2323, null, "Test"), (1715, -11.34375, " request")]
   "seed": 0,
-  "tokens": [(29899, -1.5625, "-"), (1454, -0.20410156, "for"), (29899, 0.0, "-"), (9342, 0.0, "comment"), (29901, 0.0, ":"), (396, -0.27685547, " #"), (29906, -0.4970703, "2"), (29900, -0.80615234, "0"), (29896, 0.0, "1"), (29955, -1.0751953, "7")]
+  "tokens": [(13, -2.2539062, "."), (578, -0.15563965, " The"), (3622, -0.8203125, " server"), (706, 0.0, " has"), (539, 0.0, " not"), (3686, 0.0, " yet"), (3288, 0.0, " sent"), (904, 0.0, " any"), (828, 0.0, " data"), (382, -1.5517578, ".\n\n")]
   "top_tokens": null
-  "generated_text": "Test request-for-comment: #2017"
+  "generated_text": "Test request. The server has not yet sent any data.\n\n"
Load-test generation snapshot (JSON, four generations); token entries condensed to (id, logprob, text), all with "special": false; "generated_tokens": 10, "seed": null and "top_tokens": null are unchanged in every entry:
@@ -6,87 +6,82 @@ / @@ -95,87 +90,82 @@ / @@ -184,87 +174,82 @@ / @@ -273,86 +258,81 @@
-  entry 1 prefill: [(1, null, "<s>"), (4321, -9.828125, "Test"), (2009, -9.609375, "request")]
-  entry 1 tokens: [(13, -2.3300781, "\n"), (3057, -1.8740234, "Test"), (2009, -1.2646484, " request"), (13, -1.7158203, "\n"), (3057, -1.4667969, "Test"), (2009, -0.15344238, " request"), (13, -0.81591797, "\n"), (3057, -0.22973633, "Test"), (2009, -0.007045746, " request"), (13, -0.021957397, "\n")]
-  entry 2 prefill: [(1, null, "<s>"), (4321, -9.84375, "Test"), (2009, -9.59375, "request")]
-  entry 2 tokens: [(13, -2.3378906, "\n"), (3057, -1.8779297, "Test"), (2009, -1.2636719, " request"), (13, -1.6992188, "\n"), (3057, -1.4589844, "Test"), (2009, -0.15344238, " request"), (13, -0.79052734, "\n"), (3057, -0.22937012, "Test"), (2009, -0.007041931, " request"), (13, -0.022140503, "\n")]
-  entry 3 prefill: [(1, null, "<s>"), (4321, -9.84375, "Test"), (2009, -9.609375, "request")]
-  entry 3 tokens: [(13, -2.3261719, "\n"), (3057, -1.8730469, "Test"), (2009, -1.2587891, " request"), (13, -1.6894531, "\n"), (3057, -1.46875, "Test"), (2009, -0.1541748, " request"), (13, -0.80322266, "\n"), (3057, -0.22912598, "Test"), (2009, -0.0070495605, " request"), (13, -0.021606445, "\n")]
-  entry 4 prefill: [(1, null, "<s>"), (4321, -9.84375, "Test"), (2009, -9.6015625, "request")]
-  entry 4 tokens: [(13, -2.3320312, "\n"), (3057, -1.875, "Test"), (2009, -1.2646484, " request"), (13, -1.6884766, "\n"), (3057, -1.4589844, "Test"), (2009, -0.15185547, " request"), (13, -0.79833984, "\n"), (3057, -0.22827148, "Test"), (2009, -0.006996155, " request"), (13, -0.021560669, "\n")]
-  generated_text (all four entries): "\nTest request\nTest request\nTest request\n"
+  prefill (identical in all four entries): [(2323, null, "Test"), (1715, -11.34375, " request")]
+  tokens (identical in all four entries): [(198, -2.5742188, "\n"), (262, -1.6220703, " "), (3270, -2.0410156, " \"\"\"\n"), (262, -0.015281677, " "), (422, -2.1445312, " if"), (1715, -0.92333984, " request"), (13204, -0.07672119, ".method"), (624, -0.021987915, " =="), (364, -0.39208984, " '"), (3019, -0.10638428, "POST")]
+  generated_text (all four entries): "\n \"\"\"\n if request.method == 'POST"
Generation-details snapshot (JSON, multimodal); token entries condensed to (id, old logprob -> new logprob, text):
@@ -1,130 +1,124 @@
-  "finish_reason": "length", "generated_tokens": 20,
+  "finish_reason": "eos_token", "generated_tokens": 19,
   "prefill": [], "seed": null,
   tokens: (415, -0.039886475 -> -0.03665161, " The"), (12072, -0.1430664 -> -0.13549805, " cow"), (349, -0.056488037 -> -0.05819702, " is"), (6328, -0.6855469 -> -0.6826172, " standing"), (356, -0.1685791 -> -0.1607666, " on"), (272, -0.50097656 -> -0.5073242, " the"), (10305, -0.017303467 -> -0.016418457, " beach"), (304, -1.3564453 -> -1.3916016, " and"), (272, -0.017868042 -> -0.020217896, " the"), (13088, -0.0027103424 -> -0.0028133392, " chicken"), (349, -0.003156662 -> -0.003145218, " is"), (6398, -0.37304688 -> -0.37060547, " sitting"), (356, -0.034576416 -> -0.034851074, " on"), (264, -0.29418945 -> -0.2878418, " a"), (17972, -0.042877197 -> -0.046051025, " pile"), (302, -0.00028443336 -> -0.00028848648, " of"), (2445, -0.023223877 -> -0.025772095, " money"), (28723, -0.018157959 -> -0.018127441, "."), (32002, -0.00018393993 -> -0.00019824505, "<end_of_utterance>", special: true)
-  trailing token removed: (2, -1.1920929e-07, "</s>", special: true)
   "top_tokens": null
Generation-details snapshot (JSON, new file, 89 lines); token entries condensed to (id, logprob, text), all with "special": false:
@@ -0,0 +1,89 @@
+  "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 10,
+  "prefill": [(1, null, "<s>"), (4321, -9.8359375, "Test"), (2009, -9.6171875, "request")]
+  "seed": null,
+  "tokens": [(13, -2.3417969, "\n"), (3057, -1.8730469, "Test"), (2009, -1.2626953, " request"), (13, -1.7060547, "\n"), (3057, -1.4482422, "Test"), (2009, -0.15246582, " request"), (13, -0.796875, "\n"), (3057, -0.22766113, "Test"), (2009, -0.007045746, " request"), (13, -0.021759033, "\n")]
+  "top_tokens": null
+  "generated_text": "\nTest request\nTest request\nTest request\n"
Generation-details snapshot (JSON, new file, 89 lines, sampled with seed 0); token entries condensed to (id, logprob, text), all with "special": false:
@@ -0,0 +1,89 @@
+  "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 10,
+  "prefill": [(1, null, "<s>"), (4321, -9.7890625, "Test"), (2009, -9.625, "request")]
+  "seed": 0,
+  "tokens": [(29899, -1.4980469, "-"), (1454, -0.19433594, "for"), (29899, 0.0, "-"), (9342, 0.0, "comment"), (29901, 0.0, ":"), (396, -0.27392578, " #"), (29906, -0.49389648, "2"), (29900, -0.81103516, "0"), (29896, 0.0, "1"), (29955, -1.0800781, "7")]
+  "top_tokens": null
+  "generated_text": "Test request-for-comment: #2017"
Load-test generation snapshot (JSON, new file, 358 lines, four generations); token entries condensed to (id, logprob, text), all with "special": false; every entry has "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 10, "seed": null, "top_tokens": null and "generated_text": "\nTest request\nTest request\nTest request\n":
@@ -0,0 +1,358 @@
+  entry 1 prefill: [(1, null, "<s>"), (4321, -9.8828125, "Test"), (2009, -9.5859375, "request")]
+  entry 1 tokens: [(13, -2.3359375, "\n"), (3057, -1.8623047, "Test"), (2009, -1.2451172, " request"), (13, -1.6923828, "\n"), (3057, -1.4492188, "Test"), (2009, -0.15197754, " request"), (13, -0.8022461, "\n"), (3057, -0.22583008, "Test"), (2009, -0.007095337, " request"), (13, -0.021652222, "\n")]
+  entry 2 prefill: [(1, null, "<s>"), (4321, -9.796875, "Test"), (2009, -9.625, "request")]
+  entry 2 tokens: [(13, -2.3476562, "\n"), (3057, -1.8789062, "Test"), (2009, -1.2734375, " request"), (13, -1.703125, "\n"), (3057, -1.4677734, "Test"), (2009, -0.15454102, " request"), (13, -0.7973633, "\n"), (3057, -0.23278809, "Test"), (2009, -0.006980896, " request"), (13, -0.022033691, "\n")]
+  entry 3 prefill: [(1, null, "<s>"), (4321, -9.9296875, "Test"), (2009, -9.5703125, "request")]
+  entry 3 tokens: [(13, -2.3203125, "\n"), (3057, -1.8486328, "Test"), (2009, -1.2480469, " request"), (13, -1.7060547, "\n"), (3057, -1.4511719, "Test"), (2009, -0.1529541, " request"), (13, -0.81396484, "\n"), (3057, -0.22180176, "Test"), (2009, -0.007133484, " request"), (13, -0.021835327, "\n")]
+  entry 4 prefill: [(1, null, "<s>"), (4321, -9.84375, "Test"), (2009, -9.6171875, "request")]
+  entry 4 tokens: [(13, -2.3261719, "\n"), (3057, -1.8691406, "Test"), (2009, -1.2597656, " request"), (13, -1.7070312, "\n"), (3057, -1.4550781, "Test"), (2009, -0.1538086, " request"), (13, -0.79345703, "\n"), (3057, -0.22924805, "Test"), (2009, -0.0070266724, " request"), (13, -0.021942139, "\n")]
@@ -3,7 +3,9 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_llama_gptq_handle(launcher):
-    with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
+    with launcher(
+        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq"
+    ) as handle:
         yield handle


@@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
         response.generated_text
         == " The cow is standing on the beach and the chicken is sitting on a pile of money."
     ), f"{repr(response.generated_text)}"
-    assert response.details.generated_tokens == 20
+    assert response.details.generated_tokens == 19
     assert response == response_snapshot


@ -24,7 +24,7 @@ futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
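The dependency bump above (metrics 0.21.1 to 0.23.0) is what drives the call-site churn in the router hunks that follow: newer releases of the `metrics` crate drop the `increment_counter!`/`increment_gauge!` style macros in favour of `counter!`/`gauge!`/`histogram!` macros that return a handle. A minimal sketch of the new pattern, reusing metric names that appear in this diff but otherwise standalone (assumes only the `metrics` crate as a dependency):

use std::time::Instant;

// Sketch only: the metrics 0.23 builder-style API applied throughout the hunks below.
fn record_request(start: Instant, ok: bool) {
    if ok {
        metrics::counter!("tgi_request_success").increment(1);
    } else {
        metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
    }
    metrics::gauge!("tgi_queue_size").set(0.0);
    metrics::histogram!("tgi_request_duration").record(start.elapsed().as_secs_f64());
}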
@ -91,14 +91,14 @@ impl Infer {
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1);
tracing::error!("{err}");
err
})?;

// Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
err
})?;
@ -140,7 +140,7 @@ impl Infer {
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}");
e
})
@ -214,7 +214,7 @@ impl Infer {
})
} else {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
Err(err)
}
@ -111,7 +111,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@ -124,7 +124,7 @@ async fn queue_task(
let next_batch =
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}),
}
}
@ -226,7 +226,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
@ -336,7 +336,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;

metrics::histogram!("tgi_batch_next_size", batch.size as f64);
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

Some((batch_entries, batch, next_batch_span))
}
@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}

entries.iter_mut().for_each(|(_, entry)| {
@ -219,8 +221,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@ -234,7 +236,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@ -248,11 +250,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;

metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
.record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -261,7 +267,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@ -276,7 +282,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@ -291,13 +297,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;

if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
.record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -307,7 +318,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
@ -381,7 +392,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}

@ -407,7 +418,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");

// unwrap_or is valid here as we don't care if the receiver is gone.
@ -126,7 +126,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@ -141,7 +141,7 @@ async fn queue_task(
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}
}
}
@ -248,7 +248,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
@ -399,7 +399,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;

metrics::histogram!("tgi_batch_next_size", batch.size as f64);
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

Some((batch_entries, batch, next_batch_span))
}
@ -154,8 +154,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -176,9 +176,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}

entries.iter_mut().for_each(|(_, entry)| {
@ -225,8 +227,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@ -240,7 +242,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@ -254,11 +256,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;

metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
.record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -267,7 +273,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@ -282,7 +288,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@ -297,13 +303,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;

if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
.record(timings.decode.as_secs_f64());
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -313,7 +324,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
@ -387,7 +398,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}

@ -413,7 +424,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");

// unwrap_or is valid here as we don't care if the receiver is gone.
@ -386,7 +386,7 @@ pub struct CompletionRequest {
/// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,

/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
@ -733,7 +733,7 @@ impl ChatCompletionChunk {
pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,

/// A list of messages comprising the conversation so far.
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
@ -850,7 +850,7 @@ pub enum ToolType {
Function { function: FunctionName },
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct FunctionName {
pub name: String,
}
@ -11,10 +11,11 @@ use crate::kserve::{
};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig,
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken,
Usage, Validation,
SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse,
ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@ -185,7 +186,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);

// Do not long ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
@ -301,25 +302,15 @@ pub(crate) async fn generate_internal(
);

// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!(
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
"tgi_request_validation_duration",
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
validation_time.as_secs_f64()
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
);
metrics::histogram!("tgi_request_mean_time_per_token_duration")
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
.record(time_per_token.as_secs_f64());
metrics::histogram!(
metrics::histogram!("tgi_request_generated_tokens")
"tgi_request_inference_duration",
.record(response.generated_text.generated_tokens as f64);
inference_time.as_secs_f64()
);
metrics::histogram!(
"tgi_request_mean_time_per_token_duration",
time_per_token.as_secs_f64()
);
metrics::histogram!(
"tgi_request_generated_tokens",
response.generated_text.generated_tokens as f64
);

// Send response
let mut output_text = response.generated_text.text;
@ -399,7 +390,7 @@ async fn generate_stream_internal(
span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);

tracing::debug!("Input: {}", req.inputs);

@ -427,12 +418,12 @@
let best_of = req.parameters.best_of.unwrap_or(1);
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
@ -500,13 +491,13 @@
span.record("seed", format!("{:?}", generated_text.seed));

// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);

// StreamResponse
end_reached = true;
@ -553,7 +544,7 @@ async fn generate_stream_internal(
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
}
@ -572,8 +563,8 @@ request_body = CompletionRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = Completion),
("application/json" = CompletionFinal),
("text/event-stream" = CompletionCompleteChunk),
("text/event-stream" = Chunk),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
@ -604,9 +595,10 @@ async fn completions(
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);

let CompletionRequest {
model,
max_tokens,
seed,
stop,
@ -625,7 +617,7 @@ async fn completions(

// if suffix is present throw an error
if req.suffix.is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@ -637,7 +629,7 @@ async fn completions(
}

if req.prompt.0.len() > info.max_client_batch_size {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@ -675,7 +667,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
})
.collect();
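The `..Default::default()` to `adapter_id` change above repurposes the previously unused OpenAI-style `model` field: when the client sets it to anything other than the default "tgi", the value is forwarded as `adapter_id`, presumably so a per-request (e.g. LoRA) adapter can be selected. A minimal, standalone sketch of the same filter/map pattern (the helper name is hypothetical, not part of this diff):

// Sketch only: map an optional `model` value to an adapter id, skipping the default "tgi".
fn adapter_from_model(model: Option<&str>) -> Option<String> {
    model.filter(|m| *m != "tgi").map(String::from)
}

fn main() {
    assert_eq!(adapter_from_model(Some("my-lora")), Some("my-lora".to_string()));
    assert_eq!(adapter_from_model(Some("tgi")), None);
    assert_eq!(adapter_from_model(None), None);
}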
@ -820,6 +812,10 @@ async fn completions(
}
};

let stream = stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));

let sse = Sse::new(stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
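The hunk above (and the matching one in `chat_completions` below) terminates the SSE stream with a final `data: [DONE]` event, mirroring the OpenAI streaming convention, by chaining a one-element stream onto the response stream. A minimal sketch of that chaining with the `futures` crate (string payloads stand in for the router's actual event type; assumes `futures` and `tokio` are available):

use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-ins for serialized SSE payloads produced during generation.
    let events = stream::iter(vec!["chunk-1".to_string(), "chunk-2".to_string()]);
    // Append a single sentinel item once the main stream is exhausted.
    let with_done = events.chain(stream::once(async { "[DONE]".to_string() }));
    let collected: Vec<String> = with_done.collect().await;
    assert_eq!(collected.last().map(String::as_str), Some("[DONE]"));
}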
@ -1009,8 +1005,9 @@ async fn chat_completions(
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
logprobs,
max_tokens,
messages,
@ -1039,7 +1036,7 @@ async fn chat_completions(

// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@ -1053,7 +1050,7 @@ async fn chat_completions(
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
@ -1082,7 +1079,7 @@ async fn chat_completions(
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
@ -1116,7 +1113,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};

@ -1178,6 +1175,11 @@ async fn chat_completions(
span,
)
.await;

let response_stream = response_stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));

let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
@ -1280,7 +1282,7 @@ async fn vertex_compatibility(
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);

// check that theres at least one instance
if req.instances.is_empty() {
@ -1454,6 +1456,14 @@ pub async fn run(
GrammarType,
ChatRequest,
Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,
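The `ToSchema` derive added to `FunctionName` earlier and the extra entries registered in `run()` above go together: utoipa can only list a type under `components(schemas(...))` if it implements `ToSchema`. A minimal, self-contained sketch of that pattern (the struct mirrors the one in this diff, while the `ApiDoc` wrapper is illustrative; assumes the `utoipa` crate and `serde` with the derive feature):

use serde::{Deserialize, Serialize};
use utoipa::{OpenApi, ToSchema};

// A type referenced in the OpenAPI spec must derive ToSchema.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct FunctionName {
    pub name: String,
}

// Listing it under components(schemas(...)) makes the schema appear in the generated document.
#[derive(OpenApi)]
#[openapi(components(schemas(FunctionName)))]
struct ApiDoc;

fn main() {
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}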
@ -157,7 +157,7 @@ impl Validation {
));
}

metrics::histogram!("tgi_request_input_length", input_length as f64);
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens))
}
// Return inputs without validation
@ -384,7 +384,7 @@ impl Validation {
ignore_eos_token: false,
};

metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);

Ok(ValidGenerateRequest {
inputs,
@ -35,5 +35,5 @@ run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

export-requirements:
poetry export -o requirements_cuda.txt --without-hashes
poetry export -o requirements_cuda.txt --without-hashes --with cuda
poetry export -o requirements_rocm.txt --without-hashes
@ -59,3 +59,18 @@ def marlin_gemm(
Matrix multiplication using Marlin kernels.
"""
...

# fp8 marlin
def fp8_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
)
@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}
@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);

#endif
server/marlin/marlin_kernels/fp8_marlin.cu (new file, 1308 lines; diff suppressed because it is too large)
@ -9,6 +9,7 @@ setup(
CUDAExtension(
name="marlin_kernels",
sources=[
"marlin_kernels/fp8_marlin.cu",
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",
@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)

@pytest.fixture(scope="session")
@ -16,7 +19,10 @@ def default_bloom():
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)

@pytest.fixture(scope="session")
@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch

@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
return CausalLM.fallback("gpt2")

@pytest.fixture(scope="session")
@ -1,13 +1,12 @@
import pytest

from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
from text_generation_server.models.santacoder import SantaCoder

@pytest.fixture(scope="session")
def default_santacoder():
return SantaCoder("bigcode/santacoder")
return CausalLM.fallback(model_id="bigcode/santacoder")

@pytest.fixture
@ -20,7 +20,7 @@ def mt0_small_tokenizer():

@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
return Seq2SeqLM.fallback("bigscience/mt0-small")

@pytest.fixture
@ -2,6 +2,7 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader

class ProcessGroup:
@ -42,7 +43,12 @@ class Weights:
def test_weight_hub_files_offline_error():

vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights)

input_ids = torch.arange(vocab_size)
@ -1,13 +1,47 @@
import pytest
import torch
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.weights import (
from text_generation_server.layers.gptq import GPTQWeight
DefaultWeightsLoader,
from text_generation_server.layers.exl2 import Exl2Weight
Weights,
from text_generation_server.layers.marlin import MarlinWeight
WeightsLoader,
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path

@pytest.fixture
def gptq_weights_loader():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="gptq",
quantize="gptq",
sym=True,
)

@pytest.fixture
def gptq_weights_loader_awq():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="awq",
quantize="awq",
sym=True,
)

@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)

dummy_file_system = {
"test_weights": {
"layer.0.weight": torch.tensor(
@ -58,7 +92,7 @@ dummy_file_system = {
dtype=torch.float32,
),
},
"test_get_multi_weights_row": {
"test_get_weights_row": {
"weight.weight": torch.tensor(
[
[1, 2],
@ -101,7 +135,7 @@ dummy_file_system = {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
"test_get_multi_weights_row_gptq": {
"test_get_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
@ -200,7 +234,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_exl2": {
"test_get_weights_row_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
@ -245,7 +279,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_marlin": {
"test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
@ -308,6 +342,7 @@ class MockWeights(Weights):
dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
weights_loader: Optional[WeightsLoader] = None,
):
routing = {}
self.dummy_fs = dummy_fs
@ -327,6 +362,9 @@ class MockWeights(Weights):
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
)
self._handles = {}

def _get_handle(self, filename: Union[Path, str]):
@ -412,12 +450,10 @@ def test_get_weights_col_packed():
)

prefix = "weight"
quantize = None
block_sizes = 1

w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)

@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
)

prefix = "weight"
quantize = None
block_sizes = 2

w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)

@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
)

prefix = "weight"
quantize = None
block_sizes = [1, 1]

w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)

@ -519,11 +551,9 @@ def test_get_multi_weights_col():
)

prefixes = ["weight", "weight"]
quantize = None

w = weights.get_multi_weights_col(
prefixes=prefixes,
prefixes=prefixes,
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -545,10 +575,10 @@ def test_get_multi_weights_col():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row():
|
def test_get_weights_row():
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row",
|
"test_get_weights_row",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@ -557,11 +587,9 @@ def test_get_multi_weights_row():
|
|||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = None
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_weights_row(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
@ -576,7 +604,7 @@ def test_get_multi_weights_row():
|
|||||||
# test_get_weights_col
|
# test_get_weights_col
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_col_awq():
|
def test_get_weights_col_awq(gptq_weights_loader_awq):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_weights_col_gptq",
|
"test_get_weights_col_gptq",
|
||||||
@ -585,14 +613,13 @@ def test_get_weights_col_awq():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader_awq,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "awq"
|
|
||||||
|
|
||||||
w = weights.get_weights_col(
|
w = weights.get_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = GPTQWeight(
|
expected_weight = GPTQWeight(
|
||||||
@ -617,7 +644,7 @@ def test_get_weights_col_awq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_col_gtpq():
|
def test_get_weights_col_gtpq(gptq_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_weights_col_gptq",
|
"test_get_weights_col_gptq",
|
||||||
@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "gptq"
|
|
||||||
|
|
||||||
w = weights.get_weights_col(
|
w = weights.get_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = GPTQWeight(
|
expected_weight = GPTQWeight(
|
||||||
@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=Exl2WeightsLoader(),
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "exl2"
|
|
||||||
|
|
||||||
w = weights.get_weights_col(
|
w = weights.get_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
scaled_scale_max = 0.3906 * 256
|
scaled_scale_max = 0.3906 * 256
|
||||||
@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
|
|||||||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_col_marlin():
|
def test_get_weights_col_marlin(marlin_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_weights_col_marlin",
|
"test_get_weights_col_marlin",
|
||||||
@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=marlin_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "marlin"
|
|
||||||
|
|
||||||
w = weights.get_weights_col(
|
w = weights.get_weights_col(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = MarlinWeight(
|
expected_weight = MarlinWeight(
|
||||||
@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
|
|||||||
# test_get_weights_col_packed
|
# test_get_weights_col_packed
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_col_packed_awq():
|
def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_weights_col_packed_gptq",
|
"test_get_weights_col_packed_gptq",
|
||||||
@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader_awq,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "awq"
|
|
||||||
block_sizes = 1
|
block_sizes = 1
|
||||||
|
|
||||||
w = weights.get_weights_col_packed(
|
w = weights.get_weights_col_packed(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
block_sizes=block_sizes,
|
block_sizes=block_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=Exl2WeightsLoader(),
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "exl2"
|
|
||||||
block_sizes = 1
|
block_sizes = 1
|
||||||
|
|
||||||
w = weights.get_weights_col_packed(
|
w = weights.get_weights_col_packed(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
block_sizes=block_sizes,
|
block_sizes=block_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
|
|||||||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_col_packed_gptq():
|
def test_get_weights_col_packed_gptq(gptq_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_weights_col_packed_gptq",
|
"test_get_weights_col_packed_gptq",
|
||||||
@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefixes = ["weight"]
|
prefixes = ["weight"]
|
||||||
quantize = "gptq"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
prefixes=prefixes,
|
prefixes=prefixes,
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_weights_col_packed_marlin():
|
def test_get_weights_col_packed_marlin(marlin_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_weights_col_packed_marlin",
|
"test_get_weights_col_packed_marlin",
|
||||||
@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=marlin_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "marlin"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
prefixes=[prefix],
|
prefixes=[prefix],
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
|
|||||||
# test_get_multi_weights_col
|
# test_get_multi_weights_col
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_awq():
|
def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_col_gptq",
|
"test_get_multi_weights_col_gptq",
|
||||||
@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader_awq,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefixes = ["weight"]
|
prefixes = ["weight"]
|
||||||
quantize = "awq"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
prefixes=prefixes,
|
prefixes=prefixes,
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=Exl2WeightsLoader(),
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "exl2"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
prefixes=[prefix],
|
prefixes=[prefix],
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
|
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_gptq():
|
def test_get_multi_weights_col_gptq(gptq_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_col_gptq",
|
"test_get_multi_weights_col_gptq",
|
||||||
@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefixes = ["weight"]
|
prefixes = ["weight"]
|
||||||
quantize = "gptq"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
prefixes=prefixes,
|
prefixes=prefixes,
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_col_marlin():
|
def test_get_multi_weights_col_marlin(marlin_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_col_marlin",
|
"test_get_multi_weights_col_marlin",
|
||||||
@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=marlin_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "marlin"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_col(
|
w = weights.get_multi_weights_col(
|
||||||
prefixes=[prefix],
|
prefixes=[prefix],
|
||||||
quantize=quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
|
|||||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||||
|
|
||||||
|
|
||||||
# test_get_multi_weights_row
|
# test_get_weights_row
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row_awq():
|
def test_get_weights_row_awq(gptq_weights_loader_awq):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row_gptq",
|
"test_get_weights_row_gptq",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader_awq,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "awq"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_weights_row(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = GPTQWeight(
|
expected_weight = GPTQWeight(
|
||||||
@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row_exl2():
|
def test_get_weights_row_exl2():
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row_exl2",
|
"test_get_weights_row_exl2",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=Exl2WeightsLoader(),
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "exl2"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_weights_row(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
print(w)
|
print(w)
|
||||||
|
|
||||||
@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
|
|||||||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row_gptq():
|
def test_get_weights_row_gptq(gptq_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row_gptq",
|
"test_get_weights_row_gptq",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=gptq_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "gptq"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_weights_row(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = GPTQWeight(
|
expected_weight = GPTQWeight(
|
||||||
@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
|
|||||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||||
|
|
||||||
|
|
||||||
def test_get_multi_weights_row_marlin():
|
def test_get_weights_row_marlin(marlin_weights_loader):
|
||||||
weights = MockWeights(
|
weights = MockWeights(
|
||||||
[
|
[
|
||||||
"test_get_multi_weights_row_marlin",
|
"test_get_weights_row_marlin",
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
process_group=dummy_process_group,
|
process_group=dummy_process_group,
|
||||||
dummy_fs=dummy_file_system,
|
dummy_fs=dummy_file_system,
|
||||||
|
weights_loader=marlin_weights_loader,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "weight"
|
prefix = "weight"
|
||||||
quantize = "marlin"
|
|
||||||
|
|
||||||
w = weights.get_multi_weights_row(
|
w = weights.get_weights_row(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
quantize=quantize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_weight = MarlinWeight(
|
expected_weight = MarlinWeight(
|
||||||
|
@ -353,6 +353,7 @@ def quantize(
|
|||||||
upload_to_model_id=upload_to_model_id,
|
upload_to_model_id=upload_to_model_id,
|
||||||
percdamp=percdamp,
|
percdamp=percdamp,
|
||||||
act_order=act_order,
|
act_order=act_order,
|
||||||
|
sym=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from typing import List, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from text_generation_server.utils.weights import WeightsLoader, Weights
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Exl2Weight:
|
class Exl2Weight:
|
||||||
@ -21,3 +24,60 @@ class Exl2Weight:
|
|||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return self.q_weight.device
|
return self.q_weight.device
|
||||||
|
|
||||||
|
|
||||||
|
class Exl2WeightsLoader(WeightsLoader):
|
||||||
|
"""Loader for exl2-quantized weights."""
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
raise RuntimeError("Column-packed weights are not supported for exl")
|
||||||
|
|
||||||
|
def get_weights_col(self, weights: Weights, prefix: str):
|
||||||
|
try:
|
||||||
|
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
|
||||||
|
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||||
|
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||||
|
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||||
|
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||||
|
|
||||||
|
return Exl2Weight(
|
||||||
|
q_weight=q_weight,
|
||||||
|
q_scale=q_scale,
|
||||||
|
q_invperm=q_invperm,
|
||||||
|
q_scale_max=q_scale_max,
|
||||||
|
q_groups=q_groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
try:
|
||||||
|
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
|
||||||
|
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||||
|
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||||
|
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||||
|
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||||
|
|
||||||
|
return Exl2Weight(
|
||||||
|
q_weight=q_weight,
|
||||||
|
q_scale=q_scale,
|
||||||
|
q_invperm=q_invperm,
|
||||||
|
q_scale_max=q_scale_max,
|
||||||
|
q_groups=q_groups,
|
||||||
|
)
|
||||||
|
@ -1,4 +1,23 @@
|
|||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
|
def get_fp8_linear() -> torch.nn.Module:
|
||||||
|
"""
|
||||||
|
Return an FP8 linear `Module` that is compatible with the current system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
if major == 8 and minor < 9:
|
||||||
|
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
||||||
|
|
||||||
|
return GPTQMarlinFP8Linear
|
||||||
|
|
||||||
|
# On other systems let Torch decide if the hardware supports FP8.
|
||||||
|
return Fp8Linear
|
||||||
|
|
||||||
|
|
||||||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
||||||
|
@ -1,20 +1,14 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from loguru import logger
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
from safetensors import SafetensorError
|
||||||
|
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.utils.import_utils import (
|
from text_generation_server.utils.import_utils import (
|
||||||
SYSTEM,
|
SYSTEM,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GPTQParams:
|
|
||||||
bits: int
|
|
||||||
checkpoint_format: Optional[str]
|
|
||||||
groupsize: int
|
|
||||||
desc_act: bool
|
|
||||||
quant_method: str
|
|
||||||
sym: bool
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -69,3 +63,345 @@ elif CAN_EXLLAMA:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQWeightsLoader(WeightsLoader):
|
||||||
|
"""
|
||||||
|
Loader for GPTQ- and AWQ-quantized weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bits: int,
|
||||||
|
desc_act: bool,
|
||||||
|
groupsize: int,
|
||||||
|
quant_method: str,
|
||||||
|
quantize: str,
|
||||||
|
sym: bool,
|
||||||
|
):
|
||||||
|
self.bits = bits
|
||||||
|
self.desc_act = desc_act
|
||||||
|
self.groupsize = groupsize
|
||||||
|
self.quant_method = quant_method
|
||||||
|
self.quantize = quantize
|
||||||
|
self.sym = sym
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
scales = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
if can_use_gptq_marlin(
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
quantize=self.quantize,
|
||||||
|
sym=self.sym,
|
||||||
|
):
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
qzeros = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_exllama=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
scales = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
if can_use_gptq_marlin(
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
quantize=self.quantize,
|
||||||
|
sym=self.sym,
|
||||||
|
):
|
||||||
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||||
|
|
||||||
|
use_exllama = (
|
||||||
|
self.bits == 4
|
||||||
|
and HAS_EXLLAMA
|
||||||
|
and self.quantize == "gptq"
|
||||||
|
and not self.desc_act
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._get_gptq_params(weights)
|
||||||
|
if can_use_gptq_marlin(
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
quantize=self.quantize,
|
||||||
|
sym=self.sym,
|
||||||
|
):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
if self.desc_act or self.groupsize == -1:
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
else:
|
||||||
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
|
||||||
|
sharded_in_features = weights.process_group.size() > 1
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=sharded_in_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
|
if self.bits != 4:
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
if self.desc_act:
|
||||||
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||||
|
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
|
if weights.process_group.size() > 1:
|
||||||
|
if g_idx is not None:
|
||||||
|
if (
|
||||||
|
not torch.equal(
|
||||||
|
g_idx.cpu(),
|
||||||
|
torch.tensor(
|
||||||
|
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and not (g_idx == 0).all()
|
||||||
|
):
|
||||||
|
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||||
|
# it would require to reorder input activations that are split unto several GPUs
|
||||||
|
use_exllama = False
|
||||||
|
|
||||||
|
from text_generation_server.layers.gptq import (
|
||||||
|
HAS_EXLLAMA,
|
||||||
|
CAN_EXLLAMA,
|
||||||
|
GPTQWeight,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_exllama:
|
||||||
|
if not HAS_EXLLAMA:
|
||||||
|
if CAN_EXLLAMA:
|
||||||
|
log_once(
|
||||||
|
logger.warning,
|
||||||
|
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||||
|
)
|
||||||
|
use_exllama = False
|
||||||
|
else:
|
||||||
|
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||||
|
|
||||||
|
if use_exllama and self.groupsize != -1:
|
||||||
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
else:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
if use_exllama and g_idx is not None:
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
|
|
||||||
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // self.bits),
|
||||||
|
device=qweight.device,
|
||||||
|
)
|
||||||
|
// self.groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
return GPTQWeight(
|
||||||
|
qweight=qweight,
|
||||||
|
qzeros=qzeros,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
use_exllama=use_exllama,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_gptq_params(self, weights: Weights):
|
||||||
|
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||||
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
|
self.desc_act = False
|
||||||
|
# `server quantize` used asymmetric quantization unconditionally
|
||||||
|
# before the `gptq_sym` setting tensor was added.
|
||||||
|
self.sym = (
|
||||||
|
weights.get_tensor("gptq_sym").item()
|
||||||
|
if weights._has_tensor("gptq_sym")
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
self.quant_method = "gptq"
|
||||||
|
@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||||
|
|
||||||
DEV = torch.device("cuda:0")
|
DEV = torch.device("cuda:0")
|
||||||
|
|
||||||
|
|
||||||
@ -869,6 +871,7 @@ def quantize(
|
|||||||
upload_to_model_id: Optional[str],
|
upload_to_model_id: Optional[str],
|
||||||
percdamp: float,
|
percdamp: float,
|
||||||
act_order: bool,
|
act_order: bool,
|
||||||
|
sym: bool,
|
||||||
):
|
):
|
||||||
print("loading model")
|
print("loading model")
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
@ -891,6 +894,7 @@ def quantize(
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
||||||
|
weights_loader=DefaultWeightsLoader(),
|
||||||
)
|
)
|
||||||
hooks = []
|
hooks = []
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
@ -943,6 +947,7 @@ def quantize(
|
|||||||
percdamp=percdamp,
|
percdamp=percdamp,
|
||||||
act_order=act_order,
|
act_order=act_order,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
|
sym=sym,
|
||||||
)
|
)
|
||||||
print(time.time() - tick)
|
print(time.time() - tick)
|
||||||
|
|
||||||
@ -954,6 +959,7 @@ def quantize(
|
|||||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||||
state_dict["gptq_bits"] = torch.LongTensor([bits])
|
state_dict["gptq_bits"] = torch.LongTensor([bits])
|
||||||
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
|
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
|
||||||
|
state_dict["gptq_sym"] = torch.BoolTensor([sym])
|
||||||
|
|
||||||
max_shard_size = "10GB"
|
max_shard_size = "10GB"
|
||||||
shards, index = shard_checkpoint(
|
shards, index = shard_checkpoint(
|
||||||
|
@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
|
|||||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||||
)
|
)
|
||||||
elif quantize == "fp8":
|
elif quantize == "fp8":
|
||||||
from text_generation_server.layers.fp8 import Fp8Linear
|
from text_generation_server.layers.fp8 import get_fp8_linear
|
||||||
|
|
||||||
linear = Fp8Linear(weight, bias)
|
linear = get_fp8_linear()(weight, bias)
|
||||||
elif quantize == "bitsandbytes":
|
elif quantize == "bitsandbytes":
|
||||||
try:
|
try:
|
||||||
from text_generation_server.layers.bnb import (
|
from text_generation_server.layers.bnb import (
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from loguru import logger
|
||||||
from text_generation_server.layers.gptq import GPTQParams
|
from text_generation_server.layers.fp8 import fp8_quantize
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
|
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import marlin_kernels
|
import marlin_kernels
|
||||||
@ -24,16 +26,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
|||||||
MARLIN_TILE_SIZE = 16
|
MARLIN_TILE_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
class MarlinWeightsLoader(WeightsLoader):
|
||||||
|
"""Loader for Marlin-quantized weights."""
|
||||||
|
|
||||||
|
def __init__(self, *, bits: int, is_marlin_24: bool):
|
||||||
|
self.bits = bits
|
||||||
|
self.is_marlin_24 = is_marlin_24
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
if self.is_marlin_24:
|
||||||
|
B = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
B_meta = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
s = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
|
||||||
|
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||||
|
else:
|
||||||
|
B = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
s = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
weight = MarlinWeight(B=B, s=s)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
|
if is_marlin_24:
|
||||||
|
try:
|
||||||
|
B = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
B_meta = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
s = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
B = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
s = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
weight = MarlinWeight(B=B, s=s)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
|
if is_marlin_24:
|
||||||
|
try:
|
||||||
|
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
|
||||||
|
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
|
||||||
|
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
|
||||||
|
if num_groups == 1:
|
||||||
|
# The number of groups is 1 when groupsize == -1. share
|
||||||
|
# scales between all shards in this case.
|
||||||
|
s = weights.get_tensor(f"{prefix}.s")
|
||||||
|
else:
|
||||||
|
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||||
|
|
||||||
|
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
B = weights.get_sharded(f"{prefix}.B", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `marlin` weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
|
||||||
|
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
|
||||||
|
if num_groups == 1:
|
||||||
|
# The number of groups is 1 when groupsize == -1. share
|
||||||
|
# scales between all shards in this case.
|
||||||
|
s = weights.get_tensor(f"{prefix}.s")
|
||||||
|
else:
|
||||||
|
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||||
|
weight = MarlinWeight(B=B, s=s)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_gptq_marlin(
|
||||||
|
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
|
||||||
|
) -> bool:
|
||||||
return (
|
return (
|
||||||
SYSTEM == "cuda"
|
SYSTEM == "cuda"
|
||||||
and marlin_kernels is not None
|
and marlin_kernels is not None
|
||||||
and has_sm_8_0
|
and has_sm_8_0
|
||||||
and quantize == "gptq"
|
and quantize == "gptq"
|
||||||
and gptq_params.quant_method == "gptq"
|
and quant_method == "gptq"
|
||||||
and gptq_params.bits in GPTQ_MARLIN_BITS
|
and bits in GPTQ_MARLIN_BITS
|
||||||
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
and groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||||
and gptq_params.sym
|
and sym
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -339,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module):
|
|||||||
return C
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQMarlinFP8Linear(nn.Module):
|
||||||
|
"""
|
||||||
|
FP8 GPTQ-Marlin linear layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
_check_marlin_kernels()
|
||||||
|
assert marlin_kernels is not None
|
||||||
|
|
||||||
|
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
||||||
|
|
||||||
|
qweight, scale = fp8_quantize(weight)
|
||||||
|
scale = scale.to(torch.float16)
|
||||||
|
qweight, scales = repack_fp8_for_marlin(qweight, scale)
|
||||||
|
|
||||||
|
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
|
||||||
|
out_features = scales.shape[1]
|
||||||
|
_check_valid_shape(in_features=in_features, out_features=out_features)
|
||||||
|
|
||||||
|
self.qweight = qweight
|
||||||
|
self.scales = scales
|
||||||
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
|
self.workspace = torch.zeros(
|
||||||
|
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert marlin_kernels is not None
|
||||||
|
|
||||||
|
A_flat = A.view(-1, A.shape[-1])
|
||||||
|
C = marlin_kernels.fp8_marlin_gemm(
|
||||||
|
A_flat,
|
||||||
|
self.qweight,
|
||||||
|
self.scales,
|
||||||
|
self.workspace,
|
||||||
|
8,
|
||||||
|
A_flat.shape[0],
|
||||||
|
self.scales.shape[1],
|
||||||
|
A_flat.shape[1],
|
||||||
|
)
|
||||||
|
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
C += self.bias
|
||||||
|
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Repack FP8 weights to gptq format (packed int32 elements).
|
||||||
|
"""
|
||||||
|
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
|
if fp8_tensor.shape[0] % 4 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape to prepare for packing
|
||||||
|
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
||||||
|
|
||||||
|
# Convert fp8 to uint8 (byte) representation
|
||||||
|
byte_tensor = reshaped.view(torch.uint8)
|
||||||
|
|
||||||
|
# Pack 4 uint8 values into one int32
|
||||||
|
packed = torch.zeros(
|
||||||
|
fp8_tensor.shape[0] // 4,
|
||||||
|
fp8_tensor.shape[1],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=fp8_tensor.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
|
||||||
|
|
||||||
|
return packed
|
||||||
|
|
||||||
|
|
||||||
|
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Repack FP8 tensor for GPTQ-Marlin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
out_features, in_features = weight.shape
|
||||||
|
|
||||||
|
# Torch linear layers weights with shape [out_features, in_features],
|
||||||
|
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
|
||||||
|
# so transpose before packing.
|
||||||
|
qweight = pack_fp8_as_int32(weight.t())
|
||||||
|
|
||||||
|
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||||
|
repacked = marlin_kernels.gptq_marlin_repack(
|
||||||
|
qweight, perm, in_features, out_features, 8
|
||||||
|
)
|
||||||
|
|
||||||
|
scales = scale.reshape(1, 1).repeat(1, out_features)
|
||||||
|
scales = permute_scales(scales)
|
||||||
|
|
||||||
|
return repacked, scales
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MarlinWeight:
|
class MarlinWeight:
|
||||||
"""
|
"""
|
||||||
|
@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
max_position_embeddings=rope_scaling[
|
max_position_embeddings=rope_scaling[
|
||||||
"original_max_position_embeddings"
|
"original_max_position_embeddings"
|
||||||
],
|
],
|
||||||
base=10000.0,
|
base=base,
|
||||||
device=inv_freq.device,
|
device=inv_freq.device,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
extrapolation_factor=1,
|
extrapolation_factor=1,
|
||||||
@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
beta_fast=32,
|
beta_fast=32,
|
||||||
beta_slow=1,
|
beta_slow=1,
|
||||||
)
|
)
|
||||||
elif rope_scaling["type"] == "su":
|
elif rope_scaling["type"] in ["su", "longrope"]:
|
||||||
short_factor = torch.tensor(
|
short_factor = torch.tensor(
|
||||||
rope_scaling["short_factor"], dtype=torch.float32, device=device
|
rope_scaling["short_factor"], dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
|
@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
|
|||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
except:
|
except:
|
||||||
# ...otherwise they are quantized.
|
# ...otherwise they are quantized.
|
||||||
weight = weights.get_weights_col(prefix, config.quantize)
|
weight = weights.get_weights_col(prefix)
|
||||||
should_gather = weights.process_group.size() > 1
|
should_gather = weights.process_group.size() > 1
|
||||||
elif weights.process_group.size() > 1:
|
elif weights.process_group.size() > 1:
|
||||||
try:
|
try:
|
||||||
@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
|
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
|
||||||
"""Specific method when the QKV was joined after the fact"""
|
"""Specific method when the QKV was joined after the fact"""
|
||||||
weight = weights.get_weights_col_packed_gate_up(
|
weight = weights.get_weights_col_packed_gate_up(prefix)
|
||||||
prefix, quantize=config.quantize
|
|
||||||
)
|
|
||||||
if bias:
|
if bias:
|
||||||
raise NotImplementedError("packed_gate_up only implemented without bias")
|
raise NotImplementedError("packed_gate_up only implemented without bias")
|
||||||
else:
|
else:
|
||||||
@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||||||
"""Specific method when the QKV was joined after the fact"""
|
"""Specific method when the QKV was joined after the fact"""
|
||||||
weight = weights.get_weights_col_packed_qkv(
|
weight = weights.get_weights_col_packed_qkv(
|
||||||
prefix,
|
prefix,
|
||||||
quantize=config.quantize,
|
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
num_key_value_heads=num_key_value_heads,
|
num_key_value_heads=num_key_value_heads,
|
||||||
)
|
)
|
||||||
@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, config, prefix: str, weights, bias: bool):
|
def load(cls, config, prefix: str, weights, bias: bool):
|
||||||
weight = weights.get_weights_col(prefix, config.quantize)
|
weight = weights.get_weights_col(prefix)
|
||||||
if bias:
|
if bias:
|
||||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||||
else:
|
else:
|
||||||
@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||||||
if config.quantize == "exl2":
|
if config.quantize == "exl2":
|
||||||
linears = []
|
linears = []
|
||||||
for prefix in prefixes:
|
for prefix in prefixes:
|
||||||
weight = weights.get_weights_col(prefix, config.quantize)
|
weight = weights.get_weights_col(prefix)
|
||||||
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
||||||
linears.append(get_linear(weight, b, config.quantize))
|
linears.append(get_linear(weight, b, config.quantize))
|
||||||
linear = LayerConcat(linears)
|
linear = LayerConcat(linears)
|
||||||
else:
|
else:
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col(prefixes, dim=dim)
|
||||||
prefixes, quantize=config.quantize, dim=dim
|
|
||||||
)
|
|
||||||
if bias:
|
if bias:
|
||||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||||
bias = torch.cat(b, dim=dim)
|
bias = torch.cat(b, dim=dim)
|
||||||
@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, config, prefix: str, weights, bias: bool):
|
def load(cls, config, prefix: str, weights, bias: bool):
|
||||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
weight = weights.get_weights_row(prefix)
|
||||||
|
|
||||||
if bias and weights.process_group.rank() == 0:
|
if bias and weights.process_group.rank() == 0:
|
||||||
# Rank is only on the first rank process
|
# Rank is only on the first rank process
|
||||||
|
@ -11,17 +11,27 @@ from pathlib import Path
|
|||||||
|
|
||||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||||
from text_generation_server.models.model import Model
|
from text_generation_server.models.model import Model
|
||||||
from text_generation_server.models.causal_lm import CausalLM
|
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
|
||||||
from text_generation_server.models.bloom import BLOOMSharded
|
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||||
from text_generation_server.models.mpt import MPTSharded
|
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||||
|
MPTForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.bloom import BloomCausalLMBatch
|
||||||
|
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
||||||
|
BloomForCausalLM,
|
||||||
|
)
|
||||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||||
from text_generation_server.models.rw import RW
|
from text_generation_server.models.galactica import GalacticaCausalLMBatch
|
||||||
from text_generation_server.models.opt import OPTSharded
|
from text_generation_server.models.custom_modeling.neox_modeling import (
|
||||||
from text_generation_server.models.galactica import GalacticaSharded
|
GPTNeoxForCausalLM,
|
||||||
from text_generation_server.models.santacoder import SantaCoder
|
)
|
||||||
from text_generation_server.models.t5 import T5Sharded
|
from text_generation_server.models.custom_modeling.phi_modeling import (
|
||||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
PhiConfig,
|
||||||
from text_generation_server.models.phi import Phi
|
PhiForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.t5_modeling import (
|
||||||
|
T5ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
@ -41,9 +51,6 @@ __all__ = [
|
|||||||
"CausalLM",
|
"CausalLM",
|
||||||
"GalacticaSharded",
|
"GalacticaSharded",
|
||||||
"Seq2SeqLM",
|
"Seq2SeqLM",
|
||||||
"SantaCoder",
|
|
||||||
"OPTSharded",
|
|
||||||
"T5Sharded",
|
|
||||||
"get_model",
|
"get_model",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -53,38 +60,65 @@ FLASH_ATTENTION = True
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
FlashLlamaForCausalLM,
|
||||||
from text_generation_server.models.flash_llama import (
|
|
||||||
FlashLlama,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_qwen2 import (
|
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||||
FlashQwen2,
|
FlashCohereForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_cohere import (
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||||
FlashCohere,
|
FlashGemmaForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_gemma import (
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||||
FlashGemma,
|
FlashGemma2ForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_gemma2 import (
|
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||||
FlashGemma2,
|
FlashDbrxForCausalLM,
|
||||||
|
DbrxConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||||
|
RWConfig,
|
||||||
|
FlashRWForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||||
|
FlashGPTNeoXForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.pali_gemma import (
|
from text_generation_server.models.pali_gemma import (
|
||||||
PaliGemma,
|
PaliGemmaBatch,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.flash_santacoder import (
|
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
|
||||||
FlashSantacoderSharded,
|
PaliGemmaForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||||
|
FlashPhiForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.idefics import IDEFICSSharded
|
from text_generation_server.models.idefics import IDEFICSSharded
|
||||||
from text_generation_server.models.llava_next import LlavaNext
|
from text_generation_server.models.custom_modeling.llava_next import (
|
||||||
from text_generation_server.models.idefics2 import Idefics2
|
LlavaNextForConditionalGeneration,
|
||||||
from text_generation_server.models.flash_mistral import FlashMistral
|
)
|
||||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
|
||||||
from text_generation_server.models.flash_phi import FlashPhi
|
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
FlashSantacoderForCausalLM,
|
||||||
from text_generation_server.models.flash_dbrx import FlashDbrx
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||||
|
FlashStarcoder2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||||
|
Qwen2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
|
FlashMistralForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||||
|
FlashMixtralForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||||
|
FlashGPT2ForCausalLM,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||||
|
Idefics2ForConditionalGeneration,
|
||||||
|
)
|
||||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||||
@ -93,21 +127,7 @@ except ImportError as e:
|
|||||||
|
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
__all__.append(FlashCausalLM)
|
__all__.append(FlashCausalLM)
|
||||||
__all__.append(FlashGPT2)
|
|
||||||
__all__.append(FlashNeoXSharded)
|
|
||||||
__all__.append(FlashRWSharded)
|
|
||||||
__all__.append(FlashSantacoderSharded)
|
|
||||||
__all__.append(FlashLlama)
|
|
||||||
__all__.append(IDEFICSSharded)
|
__all__.append(IDEFICSSharded)
|
||||||
__all__.append(FlashMistral)
|
|
||||||
__all__.append(FlashMixtral)
|
|
||||||
__all__.append(FlashDbrx)
|
|
||||||
__all__.append(FlashPhi)
|
|
||||||
__all__.append(FlashQwen2)
|
|
||||||
__all__.append(FlashStarcoder2)
|
|
||||||
__all__.append(FlashGemma)
|
|
||||||
__all__.append(FlashGemma2)
|
|
||||||
__all__.append(FlashCohere)
|
|
||||||
|
|
||||||
MAMBA_AVAILABLE = True
|
MAMBA_AVAILABLE = True
|
||||||
try:
|
try:
|
||||||
@ -148,6 +168,11 @@ class ModelType(enum.Enum):
|
|||||||
"name": "Gemma",
|
"name": "Gemma",
|
||||||
"url": "https://huggingface.co/google/gemma-7b",
|
"url": "https://huggingface.co/google/gemma-7b",
|
||||||
}
|
}
|
||||||
|
PALIGEMMA = {
|
||||||
|
"type": "paligemma",
|
||||||
|
"name": "PaliGemma",
|
||||||
|
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
|
||||||
|
}
|
||||||
GEMMA2 = {
|
GEMMA2 = {
|
||||||
"type": "gemma2",
|
"type": "gemma2",
|
||||||
"name": "Gemma2",
|
"name": "Gemma2",
|
||||||
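The hunk above also registers PaliGemma as a ModelType entry. Judging from the later change of `if model_type == "paligemma":` to `if model_type == PALIGEMMA:`, the entry's "type" string is what get_model dispatches on; the constant derivation below is an assumption made for illustration, not something shown in this diff:

import enum

class ModelType(enum.Enum):
    # Same dict layout as the new entry above.
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }

# Assumed: a module-level constant derived from the entry, so get_model can
# compare against PALIGEMMA instead of a raw string literal.
PALIGEMMA = ModelType.PALIGEMMA.value["type"]

model_type = "paligemma"   # e.g. read from the checkpoint's config.json
assert model_type == PALIGEMMA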
@ -445,13 +470,16 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_id.startswith("facebook/galactica"):
|
if model_id.startswith("facebook/galactica"):
|
||||||
return GalacticaSharded(
|
return CausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
# Yes, galactica is just an OPT model.
|
||||||
|
model_class=OPTForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
batch_class=GalacticaCausalLMBatch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
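The Galactica branch above is the first example of the new dispatch style: the dedicated GalacticaSharded wrapper is gone, and get_model builds a plain CausalLM around the OPT architecture, with only the batch class swapped. A hedged sketch of the equivalent call, using the imports the new models/__init__.py declares; the model id and runtime values are placeholders, not taken from this diff:

from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.galactica import GalacticaCausalLMBatch

model = CausalLM(
    model_id="facebook/galactica-125m",  # placeholder id
    model_class=OPTForCausalLM,          # Galactica is architecturally an OPT model
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    batch_class=GalacticaCausalLMBatch,  # only the prompt/batch handling is Galactica-specific
)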
@ -460,22 +488,26 @@ def get_model(
|
|||||||
and model_id.startswith("bigcode/")
|
and model_id.startswith("bigcode/")
|
||||||
):
|
):
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashSantacoderSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashSantacoderForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||||
|
num_kv_heads=1,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return SantaCoder(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
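The bigcode/santacoder branch above shows two extras that the generic FlashCausalLM path supports: a weight-name aliases mapping (copied verbatim below) and num_kv_heads=1 for multi-query attention. A sketch of the resulting call with placeholder runtime values:

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
)

model = FlashCausalLM(
    model_id="bigcode/santacoder",       # placeholder id
    model_class=FlashSantacoderForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    lora_adapter_ids=[],                 # assumed empty when no adapters are requested
    # Weight-name alias for the tied embedding / lm_head tensors, as in the diff.
    aliases={"transformer.wte.weight": ["lm_head.weight"]},
    num_kv_heads=1,                      # multi-query attention: a single KV head
)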
@ -483,38 +515,44 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == BLOOM:
|
if model_type == BLOOM:
|
||||||
return BLOOMSharded(
|
return CausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=BloomForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
batch_class=BloomCausalLMBatch,
|
||||||
)
|
)
|
||||||
elif model_type == MPT:
|
elif model_type == MPT:
|
||||||
return MPTSharded(
|
return CausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=MPTForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
batch_class=CausalLMBatchKeysLast,
|
||||||
)
|
)
|
||||||
elif model_type == GPT2:
|
elif model_type == GPT2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
try:
|
try:
|
||||||
return FlashGPT2(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGPT2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Lots of legacy models with various weight names.
|
# Lots of legacy models with various weight names.
|
||||||
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
|
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
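BLOOM and MPT above follow the same collapse: BLOOMSharded and MPTSharded disappear in favour of CausalLM plus a model_class, and any model-specific batching lives in batch_class (BloomCausalLMBatch for BLOOM, the new CausalLMBatchKeysLast, which sets keys_head_dim_last = False, for MPT). A sketch of the MPT case with a placeholder id:

from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.mpt_modeling import MPTForCausalLM

mpt = CausalLM(
    model_id="mosaicml/mpt-7b",          # placeholder id
    model_class=MPTForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    batch_class=CausalLMBatchKeysLast,   # declared in the causal_lm.py hunk further down
)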
@ -525,7 +563,7 @@ def get_model(
|
|||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -535,25 +573,28 @@ def get_model(
|
|||||||
)
|
)
|
||||||
elif model_type == GPT_NEOX:
|
elif model_type == GPT_NEOX:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashNeoXSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGPTNeoXForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
return GPTNeoxSharded(
|
return CausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=GPTNeoxForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -564,16 +605,18 @@ def get_model(
|
|||||||
|
|
||||||
elif model_type == PHI:
|
elif model_type == PHI:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashPhi(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashPhiForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -588,9 +631,11 @@ def get_model(
|
|||||||
"Legacy phi-msft is not supported with Flash Attention"
|
"Legacy phi-msft is not supported with Flash Attention"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return Phi(
|
return CausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=PhiForCausalLM,
|
||||||
|
config_class=PhiConfig,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -599,9 +644,10 @@ def get_model(
|
|||||||
|
|
||||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashLlama(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashLlamaForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -611,7 +657,7 @@ def get_model(
|
|||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -621,18 +667,22 @@ def get_model(
|
|||||||
)
|
)
|
||||||
if model_type == GEMMA:
|
if model_type == GEMMA:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashGemma(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGemmaForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -642,18 +692,22 @@ def get_model(
|
|||||||
)
|
)
|
||||||
elif model_type == GEMMA2:
|
elif model_type == GEMMA2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashGemma2(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashGemma2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
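The Gemma and Gemma2 branches above pass default_dtype=torch.bfloat16 (DBRX and PaliGemma do the same later). Going by the refactored constructor in the causal_lm.py hunk further down, this only changes the dtype used when the caller did not request one; a tiny sketch of that resolution rule under the same assumption:

import torch
from typing import Optional

def resolve_dtype(requested: Optional[torch.dtype], default_dtype: torch.dtype = torch.float16) -> torch.dtype:
    # Mirrors `dtype = default_dtype if dtype is None else dtype` in the new
    # constructors: an explicitly requested dtype always wins.
    return default_dtype if requested is None else requested

assert resolve_dtype(None, torch.bfloat16) is torch.bfloat16        # Gemma-style default
assert resolve_dtype(torch.float16, torch.bfloat16) is torch.float16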
@ -664,18 +718,20 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == COHERE:
|
if model_type == COHERE:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashCohere(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashCohereForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -686,18 +742,23 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == DBRX:
|
if model_type == DBRX:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashDbrx(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashDbrxForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
# Dbrx works better in bfloat16.
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=DbrxConfig,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -711,27 +772,41 @@ def get_model(
|
|||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
if config_dict.get("alibi", False):
|
if config_dict.get("alibi", False):
|
||||||
raise NotImplementedError("sharded is not supported for this model")
|
raise NotImplementedError("sharded is not supported for this model")
|
||||||
return FlashRWSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashRWForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
aliases={
|
||||||
|
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||||
|
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||||
|
},
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=RWConfig,
|
||||||
)
|
)
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
||||||
else:
|
else:
|
||||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||||
return FlashRWSharded(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashRWForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
aliases={
|
||||||
|
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||||
|
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||||
|
},
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
config_class=RWConfig,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return RW(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -742,18 +817,20 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == MISTRAL:
|
if model_type == MISTRAL:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashMistral(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashMistralForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -764,18 +841,20 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == MIXTRAL:
|
if model_type == MIXTRAL:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashMixtral(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashMixtralForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -786,19 +865,22 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == STARCODER2:
|
if model_type == STARCODER2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashStarcoder2(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=FlashStarcoder2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
|
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -809,17 +891,20 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == QWEN2:
|
if model_type == QWEN2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashQwen2(
|
return FlashCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=Qwen2ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
||||||
else:
|
else:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -829,9 +914,10 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == OPT:
|
if model_type == OPT:
|
||||||
return OPTSharded(
|
return CausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=OPTForCausalLM,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -839,13 +925,20 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == T5:
|
if model_type == T5:
|
||||||
return T5Sharded(
|
return Seq2SeqLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=T5ForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
aliases={
|
||||||
|
"shared.weight": [
|
||||||
|
"encoder.embed_tokens.weight",
|
||||||
|
"decoder.embed_tokens.weight",
|
||||||
|
]
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if model_type == IDEFICS:
|
if model_type == IDEFICS:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
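The T5 branch above swaps T5Sharded for the generic Seq2SeqLM, again with an aliases entry so the single shared.weight tensor can stand in for the encoder and decoder embedding names. A sketch with a placeholder id:

from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

t5 = Seq2SeqLM(
    model_id="google/flan-t5-small",     # placeholder id, not from the diff
    model_class=T5ForConditionalGeneration,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    aliases={
        # Copied verbatim from the diff: ties the shared embedding to the
        # encoder/decoder embedding tensor names.
        "shared.weight": [
            "encoder.embed_tokens.weight",
            "decoder.embed_tokens.weight",
        ]
    },
)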
@ -861,34 +954,45 @@ def get_model(
|
|||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
if model_type == IDEFICS2:
|
if model_type == IDEFICS2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return Idefics2(
|
return VlmCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=Idefics2ForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
# XXX: Extremely important to cap resolution in order to limit
|
||||||
|
# VRAM usage.
|
||||||
|
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
if model_type == "paligemma":
|
if model_type == PALIGEMMA:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return PaliGemma(
|
return VlmCausalLM(
|
||||||
model_id,
|
model_id=model_id,
|
||||||
revision,
|
model_class=PaliGemmaForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
batch_class=PaliGemmaBatch,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
|
|
||||||
if model_type == LLAVA_NEXT:
|
if model_type == LLAVA_NEXT:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return LlavaNext(
|
return VlmCausalLM(
|
||||||
model_id,
|
model_class=LlavaNextForConditionalGeneration,
|
||||||
revision,
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
speculator=speculator,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
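The vision-language branches above (Idefics2, PaliGemma, LLaVA-Next) all converge on VlmCausalLM; the Idefics2 branch additionally caps the processor's image size so the vision tower does not blow up VRAM. A sketch of that call with a placeholder id:

from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.idefics2 import (
    Idefics2ForConditionalGeneration,
)

idefics2 = VlmCausalLM(
    model_id="HuggingFaceM4/idefics2-8b",  # placeholder id, not from the diff
    model_class=Idefics2ForConditionalGeneration,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
    lora_adapter_ids=[],                   # assumed empty when no adapters are requested
    # Cap the longest/shortest image edge (values copied from the diff) to
    # bound the vision encoder's memory use.
    processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)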
@ -912,7 +1016,7 @@ def get_model(
|
|||||||
elif quantize == "exl2":
|
elif quantize == "exl2":
|
||||||
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
|
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
|
||||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -921,7 +1025,7 @@ def get_model(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
||||||
return Seq2SeqLM(
|
return Seq2SeqLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -933,7 +1037,7 @@ def get_model(
|
|||||||
auto_map = config_dict.get("auto_map", None)
|
auto_map = config_dict.get("auto_map", None)
|
||||||
if trust_remote_code and auto_map is not None:
|
if trust_remote_code and auto_map is not None:
|
||||||
if "AutoModelForCausalLM" in auto_map.keys():
|
if "AutoModelForCausalLM" in auto_map.keys():
|
||||||
return CausalLM(
|
return CausalLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
@ -942,7 +1046,7 @@ def get_model(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if "AutoModelForSeq2SeqLM" in auto_map.keys():
|
if "AutoModelForSeq2SeqLM" in auto_map.keys():
|
||||||
return Seq2SeqLM(
|
return Seq2SeqLM.fallback(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
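Every non-Flash else branch in get_model now goes through CausalLM.fallback (or Seq2SeqLM.fallback) instead of the old constructors; the classmethod, added in the causal_lm.py hunk below, keeps the plain transformers AutoModel path alive without the new model_class machinery. A hedged sketch of the call shape used above, with a placeholder id:

from text_generation_server.models.causal_lm import CausalLM

model = CausalLM.fallback(
    "gpt2",               # placeholder model id
    None,                 # revision
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
)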
@ -4,22 +4,12 @@ import torch.distributed
|
|||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
|
||||||
AutoConfig,
|
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
|
||||||
BloomForCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.models import CausalLM
|
from text_generation_server.models import CausalLM
|
||||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BloomCausalLMBatch(CausalLMBatch):
|
class BloomCausalLMBatch(CausalLMBatch):
|
||||||
@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||||||
|
|
||||||
|
|
||||||
class BLOOMSharded(CausalLM):
|
class BLOOMSharded(CausalLM):
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
slow_but_exact=False,
|
|
||||||
tp_parallel=True,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
config.pad_token_id = 3
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
process_group=self.process_group,
|
|
||||||
prefix="transformer",
|
|
||||||
)
|
|
||||||
if config.quantize in ["gptq", "marlin"]:
|
|
||||||
weights._set_gptq_params(model_id, revision)
|
|
||||||
|
|
||||||
model = BloomForCausalLM(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(CausalLM, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
requires_padding=True,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[CausalLMBatch]:
|
def batch_type(self) -> Type[CausalLMBatch]:
|
||||||
return BloomCausalLMBatch
|
return BloomCausalLMBatch
|
||||||
|
@ -1,13 +1,26 @@
|
|||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoTokenizer,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
)
|
||||||
from typing import Optional, Tuple, List, Type, Dict
|
from typing import Optional, Tuple, List, Type, Dict
|
||||||
|
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
from text_generation_server.models import Model
|
from text_generation_server.models import Model
|
||||||
from text_generation_server.utils.chunks import concat_text_chunks
|
from text_generation_server.utils.chunks import concat_text_chunks
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.quantization import get_loader
|
||||||
from text_generation_server.utils.tokens import batch_top_tokens
|
from text_generation_server.utils.tokens import batch_top_tokens
|
||||||
from text_generation_server.models.types import (
|
from text_generation_server.models.types import (
|
||||||
Batch,
|
Batch,
|
||||||
@ -478,10 +491,93 @@ class CausalLMBatch(Batch):
|
|||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CausalLMBatchKeysLast(Batch):
|
||||||
|
keys_head_dim_last: bool = False
|
||||||
|
|
||||||
|
|
||||||
class CausalLM(Model):
|
class CausalLM(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
model_class,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
speculator: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
default_dtype=torch.float16,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
tokenizer_class=AutoTokenizer,
|
||||||
|
config_class=AutoConfig,
|
||||||
|
batch_class=CausalLMBatch,
|
||||||
|
):
|
||||||
|
self.batch_class = batch_class
|
||||||
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
dtype = default_dtype if dtype is None else dtype
|
||||||
|
elif SYSTEM == "ipex":
|
||||||
|
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = default_dtype if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
# Float16 doesn't exist on target.
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = config_class.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
config.quantize = quantize
|
||||||
|
config.speculator = speculator
|
||||||
|
if tokenizer.pad_token_id is None:
|
||||||
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
weights_loader = get_loader(
|
||||||
|
quantize=quantize, model_id=model_id, revision=revision
|
||||||
|
)
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(
|
||||||
|
filenames,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
process_group=self.process_group,
|
||||||
|
weights_loader=weights_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix = ""
|
||||||
|
model = model_class(prefix, config, weights)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super().__init__(
|
||||||
|
model_id=model_id,
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fallback(
|
||||||
|
cls,
|
||||||
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
speculator: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
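The constructor above is the core of the refactor: device and dtype selection (CUDA, ipex/XPU, or CPU), overridable tokenizer/config classes, and a weights pipeline that attaches a quantization-aware loader before instantiating model_class(prefix, config, weights). A condensed sketch of that pipeline, using the helpers imported in the hunk above; it assumes a CUDA or CPU host and an already loaded config:

import torch
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.quantization import get_loader

def load_sharded_model(model_class, model_id, config, quantize=None, revision=None, dtype=torch.float16):
    # One process per shard; the rank picks the GPU this shard owns.
    process_group, rank, world_size = initialize_torch_distributed()
    device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")

    # The quantization strategy lives in a weights_loader attached to Weights,
    # which is why get_weights_col/get_weights_row no longer take quantize=.
    weights_loader = get_loader(quantize=quantize, model_id=model_id, revision=revision)
    filenames = weight_files(model_id, revision=revision, extension=".safetensors")
    weights = Weights(
        filenames,
        device=device,
        dtype=dtype,
        process_group=process_group,
        weights_loader=weights_loader,
    )
    # Every architecture class now takes an explicit prefix as its first argument.
    return model_class("", config, weights)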
@ -537,7 +633,12 @@ class CausalLM(Model):
|
|||||||
else:
|
else:
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
|
|
||||||
super(CausalLM, self).__init__(
|
self = cls.__new__(
|
||||||
|
cls,
|
||||||
|
)
|
||||||
|
self.batch_class = CausalLMBatch
|
||||||
|
super().__init__(
|
||||||
|
self,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -545,15 +646,11 @@ class CausalLM(Model):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[CausalLMBatch]:
|
def batch_type(self) -> Type[CausalLMBatch]:
|
||||||
return CausalLMBatch
|
return self.batch_class
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||||
|
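Two details of the hunks above are easy to miss: fallback allocates the instance with cls.__new__ and runs only the base Model initializer, and batch_type now returns the per-instance batch_class. A toy illustration of that construction trick (the class names here are illustrative, not from the repository):

class Base:
    def __init__(self, name):
        self.name = name

class Loader(Base):
    def __init__(self, name):
        # Stand-in for the expensive sharded-weights constructor.
        self.batch_class = dict
        super().__init__(name)

    @classmethod
    def fallback(cls, name):
        # Allocate without running Loader.__init__, as the diff does.
        self = cls.__new__(cls)
        self.batch_class = list      # lighter default, like CausalLMBatch
        Base.__init__(self, name)    # run only the base initializer
        return self

    @property
    def batch_type(self):
        return self.batch_class

loader = Loader.fallback("toy")
assert loader.batch_type is list and loader.name == "toy"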
@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class BloomForCausalLM(BloomPreTrainedModel):
|
class BloomForCausalLM(BloomPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.transformer = BloomModel(config, weights)
|
self.transformer = BloomModel(config, weights)
|
||||||
|
|
||||||
|
@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CLIPTextTransformer(nn.Module):
|
class CLIPTextTransformer(nn.Module):
|
||||||
def __init__(self, config: CLIPTextConfig):
|
def __init__(self, prefix: str, config: CLIPTextConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
embed_dim = config.hidden_size
|
embed_dim = config.hidden_size
|
||||||
@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
|
|||||||
|
|
||||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||||
|
|
||||||
def __init__(self, config: CLIPTextConfig):
|
def __init__(self, prefix, config: CLIPTextConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.text_model = CLIPTextTransformer(config)
|
self.text_model = CLIPTextTransformer(prefix, config)
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
|
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col(
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
quantize=config.quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -363,9 +362,9 @@ class CohereMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashCohereLayer(nn.Module):
|
class FlashCohereLayer(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"model.layers.{layer_id}"
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
self.self_attn = FlashCohereAttention(
|
self.self_attn = FlashCohereAttention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
)
|
)
|
||||||
@ -416,18 +415,19 @@ class FlashCohereLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashCohereModel(torch.nn.Module):
|
class FlashCohereModel(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="model.embed_tokens", weights=weights
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
FlashCohereLayer(
|
FlashCohereLayer(
|
||||||
|
prefix,
|
||||||
layer_id,
|
layer_id,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -436,7 +436,7 @@ class FlashCohereModel(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = FastLayerNorm.load_no_bias(
|
self.norm = FastLayerNorm.load_no_bias(
|
||||||
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
|
prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@ -486,10 +486,15 @@ class FlashCohereModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashCohereForCausalLM(torch.nn.Module):
|
class FlashCohereForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = FlashCohereModel(config, weights)
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = FlashCohereModel(prefix, config, weights)
|
||||||
try:
|
try:
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
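The Cohere changes above thread a prefix through every module, and FlashCohereForCausalLM resolves it so the same code can load a standalone checkpoint (tensors under model.*) or, presumably, the same architecture embedded under a parent prefix inside a larger checkpoint. A small self-contained sketch of that resolution rule:

def resolve_prefix(prefix: str, root: str = "model") -> str:
    # Standalone checkpoint: tensors live under "model.*".
    # Nested checkpoint: tensors live under "<prefix>.model.*".
    return root if not prefix else f"{prefix}.{root}"

assert resolve_prefix("") == "model"
assert resolve_prefix("language_model") == "language_model.model"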
@ -499,7 +504,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="model.embed_tokens",
|
prefix=f"{prefix}.embed_tokens",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
self.logit_scale = config.logit_scale
|
self.logit_scale = config.logit_scale
|
||||||
|
@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
|
|||||||
|
|
||||||
|
|
||||||
class DbrxConfig(PretrainedConfig):
|
class DbrxConfig(PretrainedConfig):
|
||||||
|
attribute_map = {
|
||||||
|
"hidden_size": "d_model",
|
||||||
|
"num_attention_heads": "n_heads",
|
||||||
|
"num_hidden_layers": "n_layers",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
d_model: int = 2048,
|
d_model: int = 2048,
|
||||||
@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_key_value_heads(self):
|
||||||
|
# We can't use the attribute map, since the number of KV
|
||||||
|
# heads is not top-level.
|
||||||
|
return self.attn_config.kv_n_heads
|
||||||
|
|
||||||
|
|
||||||
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
||||||
return x.view(1) if len(x.size()) == 0 else x
|
return x.view(1) if len(x.size()) == 0 else x
|
||||||
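The DbrxConfig hunks above add an attribute_map so generic code can keep reading hidden_size / num_attention_heads / num_hidden_layers, plus a property for the KV head count, which is nested inside attn_config. A standalone sketch of the same idea built on transformers' PretrainedConfig; the tiny AttnConfig stand-in is an assumption for illustration, not the real DBRX attention config:

from transformers import PretrainedConfig

class AttnConfig:
    def __init__(self, kv_n_heads):
        self.kv_n_heads = kv_n_heads

class TinyDbrxLikeConfig(PretrainedConfig):
    # Map the canonical HF attribute names onto DBRX's own names.
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(self, d_model=2048, n_heads=16, n_layers=24, kv_n_heads=8, **kwargs):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.attn_config = AttnConfig(kv_n_heads)
        super().__init__(**kwargs)

    @property
    def num_key_value_heads(self):
        # Not expressible via attribute_map because it is not top-level.
        return self.attn_config.kv_n_heads

cfg = TinyDbrxLikeConfig()
assert cfg.hidden_size == 2048 and cfg.num_attention_heads == 16
assert cfg.num_key_value_heads == 8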
@ -593,9 +605,9 @@ class DenseMoE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DbrxLayer(nn.Module):
|
class DbrxLayer(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"transformer.blocks.{layer_id}"
|
prefix = f"{prefix}.blocks.{layer_id}"
|
||||||
|
|
||||||
self.attn = DbrxNormAttentionNorm(
|
self.attn = DbrxNormAttentionNorm(
|
||||||
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
|
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
|
||||||
@ -637,16 +649,17 @@ class DbrxLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DbrxModel(torch.nn.Module):
|
class DbrxModel(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="transformer.wte", weights=weights
|
prefix=f"{prefix}.wte", weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
DbrxLayer(
|
DbrxLayer(
|
||||||
|
prefix,
|
||||||
layer_id,
|
layer_id,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -655,7 +668,7 @@ class DbrxModel(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = FastLayerNorm.load_no_bias(
|
self.norm = FastLayerNorm.load_no_bias(
|
||||||
prefix="transformer.norm_f", weights=weights, eps=1e-5
|
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
|
||||||
)
|
)
|
||||||
|
|
||||||
self.head_size = self.layers[0].attn.self_attn.head_size
|
self.head_size = self.layers[0].attn.self_attn.head_size
|
||||||
@ -702,10 +715,15 @@ class DbrxModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashDbrxForCausalLM(torch.nn.Module):
|
class FlashDbrxForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = DbrxModel(config, weights)
|
if not prefix:
|
||||||
|
prefix = "transformer"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.transformer"
|
||||||
|
|
||||||
|
self.model = DbrxModel(prefix, config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="lm_head",
|
prefix="lm_head",
|
||||||
|
@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
|
|||||||
|
|
||||||
class Gemma2FastRMSNorm(FastRMSNorm):
|
class Gemma2FastRMSNorm(FastRMSNorm):
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, prefix, weights, eps=1e-6):
|
def load(cls, prefix: str, weights, eps=1e-6):
|
||||||
dtype = weights.dtype
|
dtype = weights.dtype
|
||||||
weights.dtype = torch.float32
|
weights.dtype = torch.float32
|
||||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||||
@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
|
|||||||
return hidden_states.to(self.dtype), residual
|
return hidden_states.to(self.dtype), residual
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix: str, weights):
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
return _load_gqa(config, prefix, weights)
|
||||||
else:
|
else:
|
||||||
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
|
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col(
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
quantize=config.quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -305,7 +304,7 @@ class Gemma2MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGemma2Layer(nn.Module):
|
class FlashGemma2Layer(nn.Module):
|
||||||
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
|
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = FlashGemma2Attention(
|
self.self_attn = FlashGemma2Attention(
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
@ -376,7 +375,7 @@ class FlashGemma2Layer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGemma2Model(torch.nn.Module):
|
class FlashGemma2Model(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights, causal: bool):
|
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
@ -442,7 +441,7 @@ class FlashGemma2Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGemma2ForCausalLM(torch.nn.Module):
|
class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights, causal: bool):
|
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
embed_norm = config.hidden_size**0.5
|
embed_norm = config.hidden_size**0.5
|
||||||
|
@@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):

 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
-    def load(cls, prefix, weights, eps=1e-6):
+    def load(cls, prefix: str, weights, eps=1e-6):
         dtype = weights.dtype
         weights.dtype = torch.float32
         weight = weights.get_tensor(f"{prefix}.weight") + 1
@@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         return hidden_states.to(self.dtype), residual


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix: str, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
@@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):

     weight = weights.get_multi_weights_col(
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
         dim=0,
     )

@@ -261,7 +260,7 @@ class FlashGemmaAttention(torch.nn.Module):


 class GemmaMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         act = config.hidden_act
         self.act = (
@@ -299,7 +298,7 @@ class GemmaMLP(nn.Module):


 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@@ -354,7 +353,7 @@ class FlashGemmaLayer(nn.Module):


 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()

         process_group = weights.process_group
@@ -419,7 +418,7 @@ class FlashGemmaModel(torch.nn.Module):


 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
         super().__init__()

         embed_norm = config.hidden_size**0.5
||||||
|
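Note: the Gemma/Gemma2 hunks above are representative of the change applied across every modeling file in this commit: each module now receives an explicit `prefix` argument instead of hard-coding its weight names. Below is a minimal sketch of that pattern; `ExampleModel` and `ExampleForCausalLM` are illustrative names that do not appear in the diff, while `TensorParallelEmbedding` and the `weights` helper are used exactly as in the hunks above.

import torch

from text_generation_server.layers import TensorParallelEmbedding


class ExampleModel(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        # Every submodule derives its tensor names from the prefix it is handed.
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )


class ExampleForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()
        # Standalone loading keeps the historical root name; when the model is
        # nested (e.g. as the text tower of a multimodal model) the caller's
        # prefix is prepended, as the NeoX/Phi/Qwen2 hunks later in this diff do.
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"
        self.model = ExampleModel(prefix, config, weights, causal=causal)
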
@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
|
|||||||
# Weights
|
# Weights
|
||||||
weight = weights.get_weights_col_packed_qkv(
|
weight = weights.get_weights_col_packed_qkv(
|
||||||
f"{prefix}.c_attn",
|
f"{prefix}.c_attn",
|
||||||
config.quantize,
|
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
)
|
)
|
||||||
@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||||||
"""load_row, but with transposed weight matrices."""
|
"""load_row, but with transposed weight matrices."""
|
||||||
|
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
weight = weights.get_weights_row(prefix)
|
||||||
else:
|
else:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||||
|
|
||||||
@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||||||
def load_col(config, prefix: str, weights, bias: bool):
|
def load_col(config, prefix: str, weights, bias: bool):
|
||||||
"""load_col, but with transposed weight matrices."""
|
"""load_col, but with transposed weight matrices."""
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col([prefix], dim=1)
|
||||||
[prefix], quantize=config.quantize, dim=1
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||||
|
|
||||||
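The `load_row`/`load_col` hunks above show the second recurring change in this commit: callers stop passing `quantize=config.quantize` and use `weights.get_weights_row(prefix)` and `weights.get_multi_weights_col([...], dim=...)` instead, with the quantization strategy resolved by the loader attached to the `Weights` object (see the `get_loader`/`weights_loader` hunks further down). A rough before/after sketch; the helper names `old_call` and `new_call` are illustrative only.

def old_call(config, prefix: str, weights):
    # Former API: every call site had to forward the quantization scheme.
    return weights.get_multi_weights_row(prefix, quantize=config.quantize)


def new_call(config, prefix: str, weights):
    # New API: the Weights object already knows how to materialize the tensor,
    # so only the name (and, for column loads, the dim) is passed.
    return weights.get_weights_row(prefix)
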
@ -261,7 +258,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GPT2MLP(nn.Module):
|
class GPT2MLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.activation_function
|
act = config.activation_function
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -298,7 +295,7 @@ class GPT2MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGPT2Layer(nn.Module):
|
class FlashGPT2Layer(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = FlashGPT2Attention(
|
self.self_attn = FlashGPT2Attention(
|
||||||
prefix=f"{prefix}.attn", config=config, weights=weights
|
prefix=f"{prefix}.attn", config=config, weights=weights
|
||||||
@ -350,7 +347,7 @@ class FlashGPT2Layer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGPT2Model(torch.nn.Module):
|
class FlashGPT2Model(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
@ -414,7 +411,7 @@ class FlashGPT2Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGPT2ForCausalLM(torch.nn.Module):
|
class FlashGPT2ForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
@ -54,7 +54,7 @@ if SYSTEM == "rocm":
|
|||||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights, layer_id):
|
def load_attention(config, prefix: str, weights, layer_id):
|
||||||
# Only defined in granite.
|
# Only defined in granite.
|
||||||
bias = getattr(config, "attention_bias", False)
|
bias = getattr(config, "attention_bias", False)
|
||||||
head_size = config.hidden_size // config.num_attention_heads
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MistralMLP(nn.Module):
|
class MistralMLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights, layer_id):
|
def __init__(self, prefix: str, config, weights, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_act = config.hidden_act
|
self.hidden_act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MistralLayer(nn.Module):
|
class MistralLayer(nn.Module):
|
||||||
def __init__(self, prefix, config, weights, layer_id):
|
def __init__(self, prefix: str, config, weights, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = MistralAttention(
|
self.self_attn = MistralAttention(
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MistralModel(torch.nn.Module):
|
class MistralModel(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashMistralForCausalLM(torch.nn.Module):
|
class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights, name=None):
|
def __init__(self, prefix: str, config, weights, name=None):
|
||||||
if name is None:
|
if name is None:
|
||||||
name = "model"
|
name = "model"
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
|||||||
return x.view(1) if len(x.size()) == 0 else x
|
return x.view(1) if len(x.size()) == 0 else x
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix: str, weights):
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
return _load_gqa(config, prefix, weights)
|
||||||
else:
|
else:
|
||||||
@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
|
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col(
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
quantize=config.quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -155,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_experts(config, prefix, mat, weights):
|
def _load_experts(config, prefix: str, mat, weights):
|
||||||
if config.quantize is not None:
|
if config.quantize is not None:
|
||||||
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
||||||
|
|
||||||
@ -475,7 +474,7 @@ class DenseMoE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MixtralLayer(nn.Module):
|
class MixtralLayer(nn.Module):
|
||||||
def __init__(self, prefix, layer_id, config, weights):
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"{prefix}.layers.{layer_id}"
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
|
|
||||||
@ -536,7 +535,7 @@ class MixtralLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MixtralModel(torch.nn.Module):
|
class MixtralModel(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
@ -610,7 +609,7 @@ class MixtralModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashMixtralForCausalLM(torch.nn.Module):
|
class FlashMixtralForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = MixtralModel(prefix, config, weights)
|
self.model = MixtralModel(prefix, config, weights)
|
||||||
|
@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
|
|||||||
|
|
||||||
|
|
||||||
def load_row(config, prefix: str, weights, bias: bool):
|
def load_row(config, prefix: str, weights, bias: bool):
|
||||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
weight = weights.get_weights_row(prefix)
|
||||||
|
|
||||||
if bias and weights.process_group.rank() == 0:
|
if bias and weights.process_group.rank() == 0:
|
||||||
# Rank is only on the first rank process
|
# Rank is only on the first rank process
|
||||||
@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||||||
|
|
||||||
|
|
||||||
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
||||||
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
|
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||||
if isinstance(weight, torch.Tensor):
|
if isinstance(weight, torch.Tensor):
|
||||||
# Only on non quantized versions
|
# Only on non quantized versions
|
||||||
weight = (
|
weight = (
|
||||||
@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.embed_in = TensorParallelEmbedding(
|
self.embed_in = TensorParallelEmbedding(
|
||||||
prefix="gpt_neox.embed_in", weights=weights
|
prefix=f"{prefix}.embed_in", weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.final_layer_norm = FastLayerNorm.load(
|
self.final_layer_norm = FastLayerNorm.load(
|
||||||
prefix="gpt_neox.final_layer_norm",
|
prefix=f"{prefix}.final_layer_norm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_eps,
|
eps=config.layer_norm_eps,
|
||||||
)
|
)
|
||||||
@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.gpt_neox = FlashGPTNeoXModel(config, weights)
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "gpt_neox"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.gpt_neox"
|
||||||
|
|
||||||
|
self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
|
||||||
|
|
||||||
self.embed_out = SpeculativeHead.load(
|
self.embed_out = SpeculativeHead.load(
|
||||||
config, prefix="embed_out", weights=weights
|
config, prefix="embed_out", weights=weights
|
||||||
|
@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
|
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col(
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
quantize=config.quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -258,9 +257,9 @@ class PhiMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashPhiLayer(nn.Module):
|
class FlashPhiLayer(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"model.layers.{layer_id}"
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
self.self_attn = FlashPhiAttention(
|
self.self_attn = FlashPhiAttention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
)
|
)
|
||||||
@ -307,18 +306,19 @@ class FlashPhiLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashPhiModel(torch.nn.Module):
|
class FlashPhiModel(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="model.embed_tokens", weights=weights
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
FlashPhiLayer(
|
FlashPhiLayer(
|
||||||
|
prefix,
|
||||||
layer_id,
|
layer_id,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -378,10 +378,15 @@ class FlashPhiModel(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashPhiForCausalLM(torch.nn.Module):
|
class FlashPhiForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = FlashPhiModel(config, weights)
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = FlashPhiModel(prefix, config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="lm_head",
|
prefix="lm_head",
|
||||||
|
@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2Layer(nn.Module):
|
class Qwen2Layer(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, prefix, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"model.layers.{layer_id}"
|
prefix = f"{prefix}.layers.{layer_id}"
|
||||||
self.self_attn = Qwen2Attention(
|
self.self_attn = Qwen2Attention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
)
|
)
|
||||||
@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2Model(torch.nn.Module):
|
class Qwen2Model(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="model.embed_tokens", weights=weights
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Qwen2Layer(
|
Qwen2Layer(
|
||||||
|
prefix,
|
||||||
layer_id,
|
layer_id,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = FastRMSNorm.load(
|
self.norm = FastRMSNorm.load(
|
||||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2ForCausalLM(torch.nn.Module):
|
class Qwen2ForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = Qwen2Model(config, weights)
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = Qwen2Model(prefix, config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="lm_head",
|
prefix="lm_head",
|
||||||
|
@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
|
|||||||
|
|
||||||
|
|
||||||
def load_row(config, prefix: str, weights, bias: bool):
|
def load_row(config, prefix: str, weights, bias: bool):
|
||||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
weight = weights.get_weights_row(prefix)
|
||||||
|
|
||||||
if bias and weights.process_group.rank() == 0:
|
if bias and weights.process_group.rank() == 0:
|
||||||
# Rank is only on the first rank process
|
# Rank is only on the first rank process
|
||||||
@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
|
|||||||
attribute_map = {
|
attribute_map = {
|
||||||
"num_hidden_layers": "n_layer",
|
"num_hidden_layers": "n_layer",
|
||||||
"num_attention_heads": "n_head",
|
"num_attention_heads": "n_head",
|
||||||
|
"num_key_value_heads": "n_head_kv",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -127,7 +128,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
prefix,
|
prefix: str,
|
||||||
weights,
|
weights,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -236,7 +237,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
prefix,
|
prefix: str,
|
||||||
weights,
|
weights,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -358,7 +359,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashMLP(nn.Module):
|
class FlashMLP(nn.Module):
|
||||||
def __init__(self, config, prefix, weights):
|
def __init__(self, config, prefix: str, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.act = torch.nn.functional.gelu
|
self.act = torch.nn.functional.gelu
|
||||||
|
|
||||||
@ -380,6 +381,7 @@ class FlashRWLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_id,
|
layer_id,
|
||||||
|
prefix: str,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
):
|
):
|
||||||
@ -388,7 +390,7 @@ class FlashRWLayer(nn.Module):
|
|||||||
parallel_attn = config.parallel_attn
|
parallel_attn = config.parallel_attn
|
||||||
self.parallel_attn = parallel_attn
|
self.parallel_attn = parallel_attn
|
||||||
|
|
||||||
prefix = f"transformer.h.{layer_id}"
|
prefix = f"{prefix}.h.{layer_id}"
|
||||||
|
|
||||||
self.input_layernorm = FastLayerNorm.load(
|
self.input_layernorm = FastLayerNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm",
|
prefix=f"{prefix}.input_layernorm",
|
||||||
@ -479,7 +481,7 @@ class FlashRWLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashRWLayerNorm(nn.Module):
|
class FlashRWLayerNorm(nn.Module):
|
||||||
def __init__(self, config, prefix, weights):
|
def __init__(self, config, prefix: str, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_ln = config.num_ln_in_parallel_attn
|
self.num_ln = config.num_ln_in_parallel_attn
|
||||||
|
|
||||||
@ -518,9 +520,9 @@ class FlashRWLayerNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashRWLargeLayer(nn.Module):
|
class FlashRWLargeLayer(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, layer_id, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"transformer.h.{layer_id}"
|
prefix = f"{prefix}.h.{layer_id}"
|
||||||
|
|
||||||
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
|
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
|
||||||
|
|
||||||
@ -580,18 +582,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class FlashRWModel(FlashRWPreTrainedModel):
|
class FlashRWModel(FlashRWPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.word_embeddings = TensorParallelEmbedding(
|
self.word_embeddings = TensorParallelEmbedding(
|
||||||
prefix="transformer.word_embeddings", weights=weights
|
prefix=f"{prefix}.word_embeddings", weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.new_decoder_architecture:
|
if config.new_decoder_architecture:
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
[
|
[
|
||||||
FlashRWLargeLayer(layer_id, config, weights)
|
FlashRWLargeLayer(layer_id, prefix, config, weights)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -599,14 +601,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
[
|
[
|
||||||
FlashRWLayer(layer_id, config, weights)
|
FlashRWLayer(layer_id, prefix, config, weights)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.cache_size = self.h[0].self_attention.num_heads_kv
|
self.cache_size = self.h[0].self_attention.num_heads_kv
|
||||||
|
|
||||||
self.ln_f = FastLayerNorm.load(
|
self.ln_f = FastLayerNorm.load(
|
||||||
prefix="transformer.ln_f",
|
prefix=f"{prefix}.ln_f",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_epsilon,
|
eps=config.layer_norm_epsilon,
|
||||||
)
|
)
|
||||||
@ -653,10 +655,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.transformer = FlashRWModel(config, weights)
|
if not prefix:
|
||||||
|
prefix = "transformer"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.transformer"
|
||||||
|
|
||||||
|
self.transformer = FlashRWModel(prefix, config, weights)
|
||||||
|
|
||||||
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
|
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
|
||||||
|
|
||||||
|
@@ -17,6 +17,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
+from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
     qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
     qzeros = qzeros.to(device=weights.device)

-    gptq_params = weights._get_gptq_params()
-    if gptq_params.quant_method == "gptq":
+    loader = weights.weights_loader
+    assert isinstance(loader, GPTQWeightsLoader)
+    loader._get_gptq_params(weights)
+    if loader.quant_method == "gptq":
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         g_idx = g_idx.to(device=weights.device)
-    elif gptq_params.quant_method == "awq":
+    elif loader.quant_method == "awq":
         g_idx = None
         from text_generation_server.layers.awq.conversion_utils import (
             fast_awq_to_gptq,
@@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
         qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
-        bits=gptq_params.bits,
-        groupsize=gptq_params.groupsize,
+        bits=loader.bits,
+        groupsize=loader.groupsize,
         use_exllama=HAS_EXLLAMA,
     )

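The `_load_multi_mqa_gptq` hunks above illustrate how quantization metadata is now read: through the loader hanging off `weights.weights_loader` rather than the private `weights._get_gptq_params()` helper. A condensed sketch of that access pattern follows; the wrapper function `read_gptq_params` is illustrative, while the attribute and method names are taken from the hunk.

from text_generation_server.layers.gptq import GPTQWeightsLoader


def read_gptq_params(weights):
    loader = weights.weights_loader
    assert isinstance(loader, GPTQWeightsLoader)
    loader._get_gptq_params(weights)
    # After this call the loader exposes quant_method ("gptq" or "awq"),
    # bits and groupsize, which the kernel setup consumes.
    return loader.quant_method, loader.bits, loader.groupsize
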
||||||
@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
|
|||||||
if config.transpose:
|
if config.transpose:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||||
else:
|
else:
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||||
[prefix], quantize=config.quantize, dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||||
@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||||||
if config.transpose:
|
if config.transpose:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||||
else:
|
else:
|
||||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
weight = weights.get_weights_row(prefix)
|
||||||
|
|
||||||
if bias and weights.process_group.rank() == 0:
|
if bias and weights.process_group.rank() == 0:
|
||||||
# Rank is only on the first rank process
|
# Rank is only on the first rank process
|
||||||
@ -346,16 +347,16 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
prefix = f"transformer.h.{layer_id}"
|
prefix = f"{prefix}.h.{layer_id}"
|
||||||
self.ln_1 = FastLayerNorm.load(
|
self.ln_1 = FastLayerNorm.load(
|
||||||
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
|
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
|
||||||
)
|
)
|
||||||
self.ln_2 = FastLayerNorm.load(
|
self.ln_2 = FastLayerNorm.load(
|
||||||
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
|
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
|
||||||
)
|
)
|
||||||
self.attn = FlashMQAttention(
|
self.self_attn = FlashMQAttention(
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@ -378,7 +379,7 @@ class Block(nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
):
|
):
|
||||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||||
hidden_states = self.attn(
|
hidden_states = self.self_attn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
@ -396,25 +397,26 @@ class Block(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashSantacoderModel(nn.Module):
|
class FlashSantacoderModel(nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
self.wte = TensorParallelEmbedding(
|
self.wte = TensorParallelEmbedding(
|
||||||
prefix="transformer.wte",
|
prefix=f"{prefix}.wte",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
reduce=False,
|
reduce=False,
|
||||||
)
|
)
|
||||||
self.wpe = TensorParallelEmbedding(
|
self.wpe = TensorParallelEmbedding(
|
||||||
prefix="transformer.wpe",
|
prefix=f"{prefix}.wpe",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
reduce=False,
|
reduce=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.h = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Block(
|
Block(
|
||||||
|
prefix,
|
||||||
layer_id,
|
layer_id,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
@ -426,8 +428,8 @@ class FlashSantacoderModel(nn.Module):
|
|||||||
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
|
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
|
||||||
)
|
)
|
||||||
|
|
||||||
self.head_size = self.h[0].attn.head_size
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
self.num_heads = self.h[0].attn.num_heads
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -446,7 +448,7 @@ class FlashSantacoderModel(nn.Module):
|
|||||||
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.h):
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
@ -464,11 +466,18 @@ class FlashSantacoderModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashSantacoderForCausalLM(nn.Module):
|
class FlashSantacoderForCausalLM(nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.transformer = FlashSantacoderModel(config, weights)
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "transformer"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.transformer"
|
||||||
|
|
||||||
|
config.transpose = config.architectures[0].startswith("GPT2")
|
||||||
|
self.model = FlashSantacoderModel(prefix, config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config, prefix="transformer.wte", weights=weights
|
config, prefix=f"{prefix}.wte", weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -485,7 +494,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(
|
hidden_states = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
|
@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||||||
|
|
||||||
weight = weights.get_multi_weights_col(
|
weight = weights.get_multi_weights_col(
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
quantize=config.quantize,
|
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -417,14 +416,14 @@ class Starcoder2Layer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Starcoder2Model(torch.nn.Module):
|
class Starcoder2Model(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
self.tp_rank = process_group.rank()
|
self.tp_rank = process_group.rank()
|
||||||
self.tp_world_size = process_group.size()
|
self.tp_world_size = process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="model.embed_tokens", weights=weights
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
@ -437,7 +436,7 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||||
prefix="model.norm", weights=weights, eps=config.norm_epsilon
|
prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@ -489,10 +488,15 @@ class Starcoder2Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = Starcoder2Model(config, weights)
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = Starcoder2Model(prefix, config, weights)
|
||||||
try:
|
try:
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
@ -502,7 +506,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config,
|
config,
|
||||||
prefix="model.embed_tokens",
|
prefix=f"{prefix}.embed_tokens",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.config = config
         config.text_config.quantize = config.quantize
         config.text_config.speculator = config.speculator
-        self.language_model = load_text_model(
+        self.text_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
             weights=weights,
@@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         image_sizes: Optional[torch.LongTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ):
-        inputs_embeds = self.language_model.embed_tokens(input_ids)
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
             # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
             # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
                 input_ids, inputs_embeds, image_features
             )

-        hidden_states = self.language_model.model(
+        hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.language_model.lm_head(hidden_states)
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
         return logits, speculative_logits
||||||
|
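In `LlavaNextForConditionalGeneration` only the Python attribute is renamed (`language_model` becomes `text_model`); the `"language_model"` weight prefix passed to `load_text_model` is unchanged. A small sketch of the updated call sites, assuming `model` is an instance of this class and the remaining forward arguments are elided:

def text_logits(model, input_ids, **forward_kwargs):
    inputs_embeds = model.text_model.embed_tokens(input_ids)
    hidden_states = model.text_model.model(inputs_embeds=inputs_embeds, **forward_kwargs)
    logits, speculative_logits = model.text_model.lm_head(hidden_states)
    return logits, speculative_logits
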
@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class MPTModel(MPTPreTrainedModel):
|
class MPTModel(MPTPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
# config._validate_config()
|
# config._validate_config()
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.world_size = weights.process_group.size()
|
self.world_size = weights.process_group.size()
|
||||||
@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
|
|||||||
f"Requested norm type ({config.norm_type}) is not implemented within this repo."
|
f"Requested norm type ({config.norm_type}) is not implemented within this repo."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.wte = TensorParallelEmbedding("transformer.wte", weights)
|
self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
|
||||||
|
|
||||||
if not self.alibi:
|
if not self.alibi:
|
||||||
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
|
self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
|
MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
|
||||||
for i in range(config.n_layers)
|
for i in range(config.n_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class MPTForCausalLM(MPTPreTrainedModel):
|
class MPTForCausalLM(MPTPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "transformer"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.transformer"
|
||||||
|
|
||||||
if not config.tie_word_embeddings:
|
if not config.tie_word_embeddings:
|
||||||
raise ValueError("MPTForCausalLM only supports tied word embeddings")
|
raise ValueError("MPTForCausalLM only supports tied word embeddings")
|
||||||
self.transformer = MPTModel(config, weights)
|
self.transformer = MPTModel(prefix, config, weights)
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config, prefix="transformer.wte", weights=weights
|
config, prefix=f"{prefix}.wte", weights=weights
|
||||||
)
|
)
|
||||||
self.logit_scale = None
|
self.logit_scale = None
|
||||||
if config.logit_scale is not None:
|
if config.logit_scale is not None:
|
||||||
|
@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GPTNeoXLayer(nn.Module):
|
class GPTNeoXLayer(nn.Module):
|
||||||
def __init__(self, layer_id, config, weights):
|
def __init__(self, layer_id, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_parallel_residual = config.use_parallel_residual
|
self.use_parallel_residual = config.use_parallel_residual
|
||||||
self.input_layernorm = nn.LayerNorm.load(
|
self.input_layernorm = nn.LayerNorm.load(
|
||||||
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
|
prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_eps,
|
eps=config.layer_norm_eps,
|
||||||
)
|
)
|
||||||
self.post_attention_layernorm = nn.LayerNorm.load(
|
self.post_attention_layernorm = nn.LayerNorm.load(
|
||||||
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
|
prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_eps,
|
eps=config.layer_norm_eps,
|
||||||
)
|
)
|
||||||
self.attention = GPTNeoXAttention(
|
self.attention = GPTNeoXAttention(
|
||||||
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
|
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
|
||||||
)
|
)
|
||||||
self.mlp = GPTNeoXMLP(
|
self.mlp = GPTNeoXMLP(
|
||||||
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
|
config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
|
||||||
self.embed_in = TensorParallelEmbedding(
|
self.embed_in = TensorParallelEmbedding(
|
||||||
prefix="gpt_neox.embed_in", weights=weights
|
prefix=f"{prefix}.embed_in", weights=weights
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
GPTNeoXLayer(layer_id, config, weights)
|
GPTNeoXLayer(layer_id, prefix, config, weights)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.final_layer_norm = nn.LayerNorm.load(
|
self.final_layer_norm = nn.LayerNorm.load(
|
||||||
prefix="gpt_neox.final_layer_norm",
|
prefix=f"{prefix}.final_layer_norm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_eps,
|
eps=config.layer_norm_eps,
|
||||||
)
|
)
|
||||||
@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
|
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.gpt_neox = GPTNeoXModel(config, weights)
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "gpt_neox"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.gpt_neox"
|
||||||
|
|
||||||
|
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
|
||||||
self.embed_out = SpeculativeHead.load(
|
self.embed_out = SpeculativeHead.load(
|
||||||
config, prefix="embed_out", weights=weights
|
config, prefix="embed_out", weights=weights
|
||||||
)
|
)
|
||||||
|
@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
|
|||||||
This module learns positional embeddings up to a fixed maximum size.
|
This module learns positional embeddings up to a fixed maximum size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, weights):
|
def __init__(self, prefix: str, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.offset = 2
|
self.offset = 2
|
||||||
self.weight = nn.Parameter(
|
self.weight = nn.Parameter(
|
||||||
weights.get_tensor("model.decoder.embed_positions.weight")
|
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class OPTDecoderLayer(nn.Module):
|
class OPTDecoderLayer(nn.Module):
|
||||||
def __init__(self, layer_id: int, config: OPTConfig, weights):
|
def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
prefix = f"model.decoder.layers.{layer_id}"
|
prefix = f"{prefix}.decoder.layers.{layer_id}"
|
||||||
self.self_attn = OPTAttention(
|
self.self_attn = OPTAttention(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class OPTDecoder(OPTPreTrainedModel):
|
class OPTDecoder(OPTPreTrainedModel):
|
||||||
def __init__(self, config: OPTConfig, weights):
|
def __init__(self, prefix: str, config: OPTConfig, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
self.layerdrop = config.layerdrop
|
self.layerdrop = config.layerdrop
|
||||||
@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="model.decoder.embed_tokens", weights=weights
|
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
|
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
|
||||||
|
|
||||||
if config.word_embed_proj_dim != config.hidden_size:
|
if config.word_embed_proj_dim != config.hidden_size:
|
||||||
self.project_out = FastLinear.load(
|
self.project_out = FastLinear.load(
|
||||||
config, prefix="model.decoder.project_out", weights=weights, bias=False
|
config,
|
||||||
|
prefix=f"{prefix}.decoder.project_out",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.project_out = None
|
self.project_out = None
|
||||||
|
|
||||||
if config.word_embed_proj_dim != config.hidden_size:
|
if config.word_embed_proj_dim != config.hidden_size:
|
||||||
self.project_in = FastLinear.load(
|
self.project_in = FastLinear.load(
|
||||||
config, prefix="model.decoder.project_in", weights=weights, bias=False
|
config,
|
||||||
|
prefix=f"{prefix}.decoder.project_in",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.project_in = None
|
self.project_in = None
|
||||||
@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
# see https://github.com/facebookresearch/metaseq/pull/164
|
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||||
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||||
self.final_layer_norm = nn.LayerNorm.load(
|
self.final_layer_norm = nn.LayerNorm.load(
|
||||||
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
|
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.final_layer_norm = None
|
self.final_layer_norm = None
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
OPTDecoderLayer(layer_id, config, weights)
|
OPTDecoderLayer(layer_id, prefix, config, weights)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class OPTModel(OPTPreTrainedModel):
|
class OPTModel(OPTPreTrainedModel):
|
||||||
def __init__(self, config: OPTConfig, weights):
|
def __init__(self, prefix: str, config: OPTConfig, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.decoder = OPTDecoder(config, weights)
|
self.decoder = OPTDecoder(prefix, config, weights)
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class OPTForCausalLM(OPTPreTrainedModel):
|
class OPTForCausalLM(OPTPreTrainedModel):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.model = OPTModel(config, weights)
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.model = OPTModel(prefix, config, weights)
|
||||||
|
|
||||||
self.lm_head = SpeculativeHead.load(
|
self.lm_head = SpeculativeHead.load(
|
||||||
config, prefix="model.decoder.embed_tokens", weights=weights
|
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
|
|||||||
|
|
||||||
# PhiModel implements the embedding layer and the transformer blocks.
|
# PhiModel implements the embedding layer and the transformer blocks.
|
||||||
class PhiModel(nn.Module):
|
class PhiModel(nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_rank = weights.process_group.rank()
|
self.tp_rank = weights.process_group.rank()
|
||||||
self.tp_world_size = weights.process_group.size()
|
self.tp_world_size = weights.process_group.size()
|
||||||
self.embed_tokens = TensorParallelEmbedding(
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
prefix="transformer.embd.wte", weights=weights
|
prefix=f"{prefix}.embd.wte", weights=weights
|
||||||
)
|
)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
PhiBlock(f"transformer.h.{layer_id}", config, weights)
|
PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
|
||||||
for layer_id in range(config.n_layer)
|
for layer_id in range(config.n_layer)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -289,9 +289,15 @@ class PhiModel(nn.Module):
|
|||||||
|
|
||||||
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
|
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
|
||||||
class PhiForCausalLM(torch.nn.Module):
|
class PhiForCausalLM(torch.nn.Module):
|
||||||
def __init__(self, config, weights):
|
def __init__(self, prefix: str, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = PhiModel(config, weights)
|
|
||||||
|
if not prefix:
|
||||||
|
prefix = "transformer"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.transformer"
|
||||||
|
|
||||||
|
self.model = PhiModel(prefix, config, weights)
|
||||||
self.lm_head = PhiCausalLMHead(config, weights)
|
self.lm_head = PhiCausalLMHead(config, weights)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@@ -10,7 +10,12 @@ import numpy as np
 from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import (
+    PreTrainedTokenizerBase,
+    AutoConfig,
+    AutoTokenizer,
+    GenerationConfig,
+)
 from typing import Iterable, Optional, Tuple, List, Type, Dict

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@@ -21,6 +26,12 @@ from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+    hub,
+)
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -39,6 +50,7 @@ from text_generation_server.models.globals import (
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
+from text_generation_server.utils.quantization import get_loader
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

 from text_generation_server.utils.import_utils import (
|
||||||
@ -799,29 +811,123 @@ class FlashCausalLMBatch(Batch):
         return len(self.requests)
 
 
+ADAPTER_LAYERS = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
+
+
 class FlashCausalLM(Model):
     def __init__(
         self,
         model_id: str,
-        model: torch.nn.Module,
-        tokenizer: PreTrainedTokenizerBase,
-        num_layers: int,
-        num_kv_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        rank: int = 0,
-        world_size: int = 1,
-        sliding_window: Optional[int] = None,
+        model_class,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+        lora_adapter_ids: Optional[list] = [],
+        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
+        config_class: PreTrainedTokenizerBase = AutoConfig,
+        default_dtype=torch.float16,
+        aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads=None,
+        skip_special_tokens: bool = True,
     ):
-        self.num_layers = num_layers
-        self.num_kv_heads = num_kv_heads
-        self.head_size = head_size
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = default_dtype if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                # Float16 doesn't exist on target.
+                dtype = torch.bfloat16 if dtype is None else dtype
+        else:
+            raise NotImplementedError(f"{model_class} is only available on GPU")
+
+        tokenizer = tokenizer_class.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass
+
+        config = config_class.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.speculator = speculator
+
+        torch.distributed.barrier(group=self.process_group)
+
+        weights_loader = get_loader(quantize, model_id, revision)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases=aliases,
+            weights_loader=weights_loader,
+        )
+
+        prefix = ""
+        model = model_class(prefix, config, weights)
+        torch.distributed.barrier(group=self.process_group)
+
+        # VLM models define the config we care about in their text_config
+        text_config = getattr(config, "text_config", None)
+        if text_config is not None:
+            config = text_config
+
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
+        self.num_layers = config.num_hidden_layers
+        # Validation is done in the model itself
+        if num_kv_heads is None:
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            # GPT-2 workaround
+            if num_kv_heads is None:
+                num_kv_heads = getattr(config, "n_head", None)
+        if num_kv_heads is None:
+            raise ValueError("Cannot get the number of key/value heads")
+        self.num_kv_heads = (
+            num_kv_heads // self.process_group.size()
+            if num_kv_heads > 1
+            else num_kv_heads
+        )
+        assert self.num_kv_heads > 0
+        self.head_size = config.hidden_size // config.num_attention_heads
 
         self.cuda_graphs = {}
         self.kv_cache = []
 
-        super(FlashCausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
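Illustrative note (not part of the diff): callers now hand FlashCausalLM a model class instead of a prebuilt model; distribution setup, tokenizer, config, and weight loading happen inside __init__. A hedged sketch of a call site follows; the model id and keyword values are assumptions, only the class names come from this diff.

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)

# dtype=None falls back to default_dtype (float16 on CUDA; bfloat16 on the CPU fallback).
llama = FlashCausalLM(
    model_id="meta-llama/Llama-2-7b-hf",   # assumed model id
    model_class=FlashLlamaForCausalLM,
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
)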
@ -830,7 +936,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
-            sliding_window=sliding_window,
+            sliding_window=config.sliding_window,
         )
 
     @property
@ -1578,3 +1684,72 @@ class FlashCausalLM(Model):
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
+
+    @property
+    def supports_adapter_loading(self) -> bool:
+        return True
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        layer_weights = {}
+
+        prefix = "model.layers"
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
+            layer_weights[(i, "q_proj")] = (
+                f"{prefix}.{i}.self_attn.q_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, "k_proj")] = (
+                f"{prefix}.{i}.self_attn.k_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, "v_proj")] = (
+                f"{prefix}.{i}.self_attn.v_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, "o_proj")] = (
+                f"{prefix}.{i}.self_attn.o_proj",
+                layer.self_attn.o_proj,
+            )
+
+            # TODO: this is a hack to avoid the gate_proj for
+            # FlashStarcoder2 that doesnt have these layers
+            if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
+                layer_weights[(i, "gate_proj")] = (
+                    f"{prefix}.{i}.mlp.gate_proj",
+                    layer.mlp.gate_up_proj,
+                )
+                layer_weights[(i, "up_proj")] = (
+                    f"{prefix}.{i}.mlp.up_proj",
+                    layer.mlp.gate_up_proj,
+                )
+                layer_weights[(i, "down_proj")] = (
+                    f"{prefix}.{i}.mlp.down_proj",
+                    layer.mlp.down_proj,
+                )
+
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
+        return layer_weights
+
+    @property
+    def adapter_layers(self) -> List[str]:
+        return ADAPTER_LAYERS
+
+    @property
+    def default_traced_adapter_layers(self) -> List[str]:
+        return ["q_proj", "v_proj"]
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
+
+    def is_row_parallel(self, layer_type: str) -> bool:
+        return layer_type in ROW_PARALLEL
@ -1,75 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import AutoTokenizer, AutoConfig
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
-    FlashCohereForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashCohere(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashCohere is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashCohereForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashCohere, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,100 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import AutoTokenizer
-from transformers.models.gpt2 import GPT2TokenizerFast
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
-    FlashDbrxForCausalLM,
-    DbrxConfig,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashDbrx(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashDBRX is only available on GPU")
-
-        try:
-            tokenizer = GPT2TokenizerFast.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-                use_fast=True,
-                from_slow=False,
-            )
-        except:
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(
-                    model_id,
-                    revision=revision,
-                    padding_side="left",
-                    truncation_side="left",
-                    trust_remote_code=trust_remote_code,
-                    use_fast=True,
-                    from_slow=False,
-                )
-            except:
-                # FIXME: change back to model id once the tokenizer.json is merged
-                tokenizer = GPT2TokenizerFast.from_pretrained(
-                    "Xenova/dbrx-instruct-tokenizer",
-                    revision=revision,
-                    padding_side="left",
-                    truncation_side="left",
-                    trust_remote_code=trust_remote_code,
-                    use_fast=True,
-                    from_slow=False,
-                )
-
-        config = DbrxConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashDbrxForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashDbrx, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,83 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import AutoConfig, AutoTokenizer
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
-    FlashGemmaForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashGemma(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashGemma is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        # TODO hardcoded
-        prefix = ""
-        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashGemma, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,83 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional
-from transformers import PretrainedConfig, AutoTokenizer
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
-    FlashGemma2ForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashGemma2(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashGemma2 is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = PretrainedConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        # TODO hardcoded
-        prefix = ""
-        model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashGemma2, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,82 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, GenerationConfig
-from transformers.models.gpt2 import GPT2Tokenizer
-from typing import Optional
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
-    FlashGPT2ForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashGPT2(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashGPT2 is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = FlashGPT2ForCausalLM(prefix, config, weights)
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashGPT2, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
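Illustrative note (not part of the diff): the removed FlashGPT2 wrapper read the KV-head count from the model itself; GPT-2 style configs are now covered by the n_head fallback added to FlashCausalLM.__init__ earlier in this diff. The helper below only restates that fallback chain in isolation and is a sketch, not project code.

def resolve_num_kv_heads(config, override=None):
    # Mirrors the fallback added to FlashCausalLM.__init__ in this diff.
    num_kv_heads = override
    if num_kv_heads is None:
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        # GPT-2 style configs only expose n_head.
        if num_kv_heads is None:
            num_kv_heads = getattr(config, "n_head", None)
    if num_kv_heads is None:
        raise ValueError("Cannot get the number of key/value heads")
    return num_kv_heads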
@ -1,171 +0,0 @@
-import os
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, GenerationConfig
-from typing import Optional, Tuple, Dict, List
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-    FlashLlamaForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-    hub,
-)
-
-tracer = trace.get_tracer(__name__)
-
-from text_generation_server.utils.import_utils import SYSTEM
-
-ADAPTER_LAYERS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
-ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
-
-
-class FlashLlama(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        lora_adapter_ids: Optional[list] = [],
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashLlama is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        try:
-            generation_config = GenerationConfig.from_pretrained(
-                model_id, revision=revision, trust_remote_code=trust_remote_code
-            )
-            if isinstance(generation_config.eos_token_id, (list, set)):
-                # TODO Huge hack
-                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
-        except Exception:
-            pass
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = FlashLlamaForCausalLM(prefix, config, weights)
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashLlama, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
-
-    @property
-    def supports_adapter_loading(self) -> bool:
-        return True
-
-    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
-        layer_weights = {}
-
-        prefix = "model.layers"
-
-        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
-        # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
-            _model = self.model.text_model
-        else:
-            _model = self.model
-
-        for i, layer in enumerate(_model.model.layers):
-            layer_weights[(i, "q_proj")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "k_proj")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "v_proj")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, "o_proj")] = (
-                f"{prefix}.{i}.self_attn.o_proj",
-                layer.self_attn.o_proj,
-            )
-
-            layer_weights[(i, "gate_proj")] = (
-                f"{prefix}.{i}.mlp.gate_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "up_proj")] = (
-                f"{prefix}.{i}.mlp.up_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "down_proj")] = (
-                f"{prefix}.{i}.mlp.down_proj",
-                layer.mlp.down_proj,
-            )
-
-        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
-        return layer_weights
-
-    @property
-    def adapter_layers(self) -> List[str]:
-        return ADAPTER_LAYERS
-
-    @property
-    def default_traced_adapter_layers(self) -> List[str]:
-        return ["q_proj", "v_proj"]
-
-    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
-
-    def is_row_parallel(self, layer_type: str) -> bool:
-        return layer_type in ROW_PARALLEL
@ -1,24 +1,7 @@
 import torch
-import torch.distributed
 
-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Dict, List
 
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import set_sliding_window
-from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-    FlashMistralForCausalLM,
-    MistralConfig,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
 
-tracer = trace.get_tracer(__name__)
-
-
 ADAPTER_LAYERS = [
@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
 ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 
 
-class BaseFlashMistral(FlashCausalLM):
-    def __init__(
-        self,
-        model_cls,
-        model_id: str,
-        config_cls=AutoConfig,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-        tokenizer_class=AutoTokenizer,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashMistral is only available on GPU")
-
-        tokenizer = tokenizer_class.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = config_cls.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        # Set context windows
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        prefix = ""
-        model = model_cls(prefix, config, weights)
-
-        self.cuda_graphs = {}
-
-        torch.distributed.barrier(group=self.process_group)
-        num_layers, num_kv_heads, head_size = self.get_layer_config(model)
-        super().__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=num_layers,
-            num_kv_heads=num_kv_heads,
-            head_size=head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-            sliding_window=config.sliding_window,
-        )
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.model.layers),
-            model.model.num_key_value_heads,
-            model.model.head_size,
-        )
-
+class FlashMistral(FlashCausalLM):
     @property
     def supports_adapter_loading(self) -> bool:
         return True
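Illustrative note (not part of the diff): with its own __init__ removed, FlashMistral inherits the generic FlashCausalLM constructor, so a caller would hand it the Mistral modeling class directly and sliding-window handling happens in the base class. A hedged sketch; the model id and the exact call path are assumptions.

from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
)

mistral = FlashMistral(
    model_id="mistralai/Mistral-7B-v0.1",  # assumed id
    model_class=FlashMistralForCausalLM,
)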
@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):
 
         # This accounts for VLMs (e.g. LlavaNext, Idefics2)
         # that have a language_model inside of the larger model.
-        if hasattr(self.model, "language_model"):
-            _model = self.model.language_model
-        elif hasattr(self.model, "text_model"):
+        if hasattr(self.model, "text_model"):
             _model = self.model.text_model
         else:
             _model = self.model
@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
-
-
-class FlashMistral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMistral, self).__init__(
-            config_cls=MistralConfig,
-            model_cls=FlashMistralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
@ -1,31 +0,0 @@
-import torch
-
-from typing import Optional
-
-from text_generation_server.models.flash_mistral import BaseFlashMistral
-from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
-    MixtralConfig,
-    FlashMixtralForCausalLM,
-)
-
-
-class FlashMixtral(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashMixtral, self).__init__(
-            config_cls=MixtralConfig,
-            model_cls=FlashMixtralForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
@ -1,82 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_neox_modeling import (
-    FlashGPTNeoXForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashNeoXSharded(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashNeoX is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
-        )
-        if config.quantize in ["gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashGPTNeoXForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashNeoXSharded, self).__init__(
-            model_id=model_id,
-            model=model.to(device),
-            tokenizer=tokenizer,
-            num_layers=len(model.gpt_neox.layers),
-            num_kv_heads=model.gpt_neox.num_heads,
-            head_size=model.gpt_neox.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,111 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
-from typing import Optional
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_phi_modeling import (
-    FlashPhiForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashPhi(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashPhi is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashPhiForCausalLM(config, weights)
-        if speculator:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(speculator).exists() and Path(speculator).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    speculator, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    speculator, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(speculator) / "config.json")
-                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashPhi, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,93 +0,0 @@
-import math
-
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional
-
-from text_generation_server.models.flash_mistral import (
-    BaseFlashMistral,
-    set_sliding_window,
-)
-from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
-    Qwen2ForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashQwen2(BaseFlashMistral):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashQwen2 is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-
-        # Set context windows
-        if config.sliding_window is not None:
-            set_sliding_window(config.sliding_window)
-
-        torch.distributed.barrier(group=self.process_group)
-
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = Qwen2ForCausalLM(config, weights)
-
-        self.cuda_graphs = {}
-
-        torch.distributed.barrier(group=self.process_group)
-        super(BaseFlashMistral, self).__init__(
-            model_id=model_id,
-            model=model,
-            tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-            sliding_window=config.sliding_window,
-        )
@ -1,91 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoTokenizer
-from typing import Optional
-
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_rw_modeling import (
-    RWConfig,
-    FlashRWForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashRWSharded(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashRW is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = RWConfig.from_pretrained(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device,
-            dtype,
-            process_group=self.process_group,
-            aliases={
-                "lm_head.weight": ["transformer.word_embeddings.weight"],
-                "transformer.word_embeddings.weight": ["lm_head.weight"],
-            },
-        )
-
-        config.quantize = quantize
-        config.speculator = speculator
-        if config.quantize in ["gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashRWForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashRWSharded, self).__init__(
-            model_id=model_id,
-            model=model.to(device),
-            tokenizer=tokenizer,
-            num_layers=len(model.transformer.h),
-            num_kv_heads=model.transformer.cache_size,
-            head_size=model.transformer.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
@ -1,99 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, List
-import json
-import os
-
-from huggingface_hub import hf_hub_download
-from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
-    FlashSantacoderForCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-
-from text_generation_server.utils.import_utils import SYSTEM
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashSantacoderSharded(FlashCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        speculator: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = torch.float16 if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                dtype = torch.bfloat16 if dtype is None else dtype
-        else:
-            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-
-        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=True,
-        )
-        config.quantize = quantize
-        config.speculator = speculator
-        config.transpose = config.architectures[0].startswith("GPT2")
-
-        torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(
-            filenames,
-            device=device,
-            dtype=dtype,
-            process_group=self.process_group,
-            aliases={"transformer.wte.weight": ["lm_head.weight"]},
-        )
-        if config.quantize in ["gptq", "marlin"]:
-            weights._set_gptq_params(model_id, revision)
-
-        model = FlashSantacoderForCausalLM(config, weights)
-
-        torch.distributed.barrier(group=self.process_group)
-        super(FlashSantacoderSharded, self).__init__(
-            model_id=model_id,
-            model=model.to(device),
-            tokenizer=tokenizer,
-            num_layers=len(model.transformer.h),
-            num_kv_heads=1,
-            head_size=model.transformer.head_size,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            world_size=world_size,
-        )
-
-    def decode(self, generated_ids: List[int]) -> str:
-        # Do not skip special tokens as they are used for custom parsing rules of the generated text
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-        )
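Illustrative note (not part of the diff): the deleted Santacoder wrapper hard-coded a single KV head and a word-embedding alias; the consolidated constructor exposes those as the num_kv_heads and aliases overrides it documents in its signature. A sketch of such a call site; the model id, the call site itself, and any extra wiring (e.g. a model-specific config_class) are assumptions.

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
)

santacoder = FlashCausalLM(
    model_id="bigcode/santacoder",            # assumed id
    model_class=FlashSantacoderForCausalLM,
    num_kv_heads=1,                           # multi-query attention: one shared KV head
    aliases={"transformer.wte.weight": ["lm_head.weight"]},
    trust_remote_code=True,
)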
@ -1,84 +0,0 @@
import math

import torch

from typing import Optional

from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
    Starcoder2Config,
    FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashStarcoder2 is only available on GPU")

        tokenizer = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = Starcoder2Config.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        # Set context windows
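        # The sliding-window size is registered globally for the attention path to use.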
        if config.sliding_window is not None:
            set_sliding_window(config.sliding_window)

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashStarcoder2ForCausalLM(config, weights)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )
@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )


class GalacticaSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        tokenizer.pad_token_id = config.pad_token_id
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, speculative_logits, outputs.past_key_values
@ -1,89 +0,0 @@
import torch
import torch.distributed

from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
    GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class GPTNeoxSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = GPTNeoxForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values
@ -23,6 +23,7 @@ from text_generation_server.utils import (
    weight_files,
    Weights,
)
from text_generation_server.utils.quantization import get_loader


class IDEFICSSharded(IdeficsCausalLM):
@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
            trust_remote_code=trust_remote_code,
        )

        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            weights_loader=weights_loader,
        )

        model = IdeficsForVisionText2Text(config, weights)
@ -1,51 +0,0 @@
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
    Idefics2ForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class Idefics2(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            # XXX: Extremely important to cap resolution in order to limit
            # VRAM usage.
            size={"longest_edge": 448, "shortest_edge": 378},
        )
        super().__init__(
            model_cls=Idefics2ForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
@ -1,46 +0,0 @@
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class LlavaNext(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        super().__init__(
            model_cls=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)
@ -28,6 +28,7 @@ from text_generation_server.models.types import (
    GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -448,8 +449,17 @@ class Mamba(Model):
        config.quantize = quantize
        config.speculator = speculator
        torch.distributed.barrier(group=self.process_group)
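        # Quantization-specific weight loading is delegated to a loader resolved from the quantize setting.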
        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        weights = Weights(
            filenames,
            device,
            dtype,
            process_group=self.process_group,
            weights_loader=weights_loader,
        )
        model = MambaModel(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(Mamba, self).__init__(
@ -60,7 +60,7 @@ class Model(ABC):
        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.target_to_layer = self.adapter_target_to_layer()
        self.target_to_layer = None
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

@ -187,6 +187,8 @@ class Model(ABC):
        into model. Otherwise, the adapter weights are applied during the forward
        pass and stored separately from the base model parameters.
        """
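        # Build the adapter target-to-layer mapping lazily, the first time an adapter is loaded.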
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return
@ -1,105 +0,0 @@
import torch
import torch.distributed

from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class MPTCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
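        # MPT keeps the head dimension before the sequence dimension in its key cache.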
        batch.keys_head_dim_last = False
        return batch


class MPTSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        # If model_id is a local path, load the file directly
        local_path = Path(model_id, "config.json")
        if local_path.exists():
            filename = str(local_path.resolve())
        else:
            filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
        with open(filename, "r") as f:
            config = json.load(f)
        config = PretrainedConfig(**config)
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        config.quantize = quantize
        model = MPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return MPTCausalLMBatch
@ -1,86 +0,0 @@
import torch
import torch.distributed

from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class OPTSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return outputs.logits, speculative_logits, outputs.past_key_values
@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
Some files were not shown because too many files have changed in this diff