Merge branch 'main' into ci_amd3

This commit is contained in:
fxmarty 2024-07-08 13:06:39 +02:00
commit 8c590be463
84 changed files with 2061 additions and 3022 deletions

View File

@ -11,10 +11,30 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
- name: Install Protocol Buffers compiler
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev
- name: Install Launcher - name: Install Launcher
id: install-launcher id: install-launcher
run: cargo install --path launcher/ run: cargo install --path launcher/
- name: Check launcher Docs are up-to-date
- name: Install router
id: install-router
run: cargo install --path router/
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Check that documentation is up-to-date
run: | run: |
echo text-generation-launcher --help
python update_doc.py --check python update_doc.py --check

View File

@ -11,6 +11,11 @@ on:
# - rocm # - rocm
# - xpu # - xpu
required: true required: true
release-tests:
description: "Run release integration tests"
required: true
default: false
type: boolean
jobs: jobs:
build-and-push: build-and-push:
@ -195,7 +200,7 @@ jobs:
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env: env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }} PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@ -20,7 +20,14 @@ on:
- "Dockerfile_amd" - "Dockerfile_amd"
- "Dockerfile_intel" - "Dockerfile_intel"
branches: branches:
- 'main' - "main"
workflow_dispatch:
inputs:
release-tests:
description: "Run release integration tests"
required: true
default: false
type: boolean
jobs: jobs:
build: build:
@ -33,4 +40,6 @@ jobs:
uses: ./.github/workflows/build.yaml # calls the one above ^ uses: ./.github/workflows/build.yaml # calls the one above ^
with: with:
hardware: ${{ matrix.hardware }} hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206
release-tests: ${{ inputs.release-tests == true }}
secrets: inherit secrets: inherit

View File

@ -9,7 +9,7 @@ members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "2.1.1-dev0" version = "2.1.2-dev0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -4,7 +4,7 @@ WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner FROM chef AS planner
COPY Cargo.lock Cargo.lock COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -38,7 +38,7 @@ RUN cargo build --profile release-opt
# Python builder # Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
ARG PYTORCH_VERSION=2.3.0 ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10 ARG PYTHON_VERSION=3.10
@ -81,7 +81,7 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# CUDA kernels builder image # CUDA kernels builder image
FROM pytorch-install as kernel-builder FROM pytorch-install AS kernel-builder
ARG MAX_JOBS=8 ARG MAX_JOBS=8
@ -90,7 +90,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Build Flash Attention CUDA kernels # Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder FROM kernel-builder AS flash-att-builder
WORKDIR /usr/src WORKDIR /usr/src
@ -100,7 +100,7 @@ COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention RUN make build-flash-attention
# Build Flash Attention v2 CUDA kernels # Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src WORKDIR /usr/src
@ -110,14 +110,14 @@ COPY server/Makefile-flash-att-v2 Makefile
RUN make build-flash-attention-v2-cuda RUN make build-flash-attention-v2-cuda
# Build Transformers exllama kernels # Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/exllama_kernels/ . COPY server/exllama_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers exllama kernels # Build Transformers exllama kernels
FROM kernel-builder as exllamav2-kernels-builder FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/exllamav2_kernels/ . COPY server/exllamav2_kernels/ .
@ -125,42 +125,42 @@ COPY server/exllamav2_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers awq kernels # Build Transformers awq kernels
FROM kernel-builder as awq-kernels-builder FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-awq Makefile COPY server/Makefile-awq Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
# Build eetq kernels # Build eetq kernels
FROM kernel-builder as eetq-kernels-builder FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-eetq Makefile COPY server/Makefile-eetq Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
# Build marlin kernels # Build marlin kernels
FROM kernel-builder as marlin-kernels-builder FROM kernel-builder AS marlin-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/marlin/ . COPY server/marlin/ .
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Lorax Punica kernels # Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers # Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/custom_kernels/ . COPY server/custom_kernels/ .
# Build specific version of transformers # Build specific version of transformers
RUN python setup.py build RUN python setup.py build
# Build vllm CUDA kernels # Build vllm CUDA kernels
FROM kernel-builder as vllm-builder FROM kernel-builder AS vllm-builder
WORKDIR /usr/src WORKDIR /usr/src
@ -172,13 +172,13 @@ COPY server/Makefile-vllm Makefile
RUN make build-vllm-cuda RUN make build-vllm-cuda
# Build mamba kernels # Build mamba kernels
FROM kernel-builder as mamba-builder FROM kernel-builder AS mamba-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile COPY server/Makefile-selective-scan Makefile
RUN make build-all RUN make build-all
# Text Generation Inference base image # Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
# Conda env # Conda env
ENV PATH=/opt/conda/bin:$PATH \ ENV PATH=/opt/conda/bin:$PATH \
@ -260,7 +260,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
# AWS Sagemaker compatible image # AWS Sagemaker compatible image
FROM base as sagemaker FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh RUN chmod +x entrypoint.sh

View File

@ -4,7 +4,7 @@ WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner FROM chef AS planner
COPY Cargo.lock Cargo.lock COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -37,7 +37,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt
# Text Generation Inference base image for RoCm # Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \ build-essential \
@ -118,7 +118,7 @@ ARG BUILD_CAFFE2="0" \
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1 ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
@ -150,26 +150,26 @@ COPY server/Makefile-flash-att-v2 Makefile
RUN make build-flash-attention-v2-rocm RUN make build-flash-attention-v2-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom) # Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder as custom-kernels-builder FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/custom_kernels/ . COPY server/custom_kernels/ .
RUN python setup.py build RUN python setup.py build
# Build exllama kernels # Build exllama kernels
FROM kernel-builder as exllama-kernels-builder FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/exllama_kernels/ . COPY server/exllama_kernels/ .
RUN python setup.py build RUN python setup.py build
# Build exllama v2 kernels # Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/exllamav2_kernels/ . COPY server/exllamav2_kernels/ .
RUN python setup.py build RUN python setup.py build
FROM base as base-copy FROM base AS base-copy
# Text Generation Inference base env # Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HUGGINGFACE_HUB_CACHE=/data \
@ -208,7 +208,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image # AWS Sagemaker compatible image
FROM base as sagemaker FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh RUN chmod +x entrypoint.sh

View File

@ -5,7 +5,7 @@ WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner FROM chef AS planner
COPY Cargo.lock Cargo.lock COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -40,7 +40,7 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel # Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
USER root USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@ -95,7 +95,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
# Text Generation Inference base image for Intel-cpu # Text Generation Inference base image for Intel-cpu
FROM ubuntu:22.04 as cpu FROM ubuntu:22.04 AS cpu
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \ curl \
@ -172,6 +172,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher # Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} as final FROM ${PLATFORM} AS final
ENTRYPOINT ["text-generation-launcher"] ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"] CMD ["--json-output"]

View File

@ -79,7 +79,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $model ghcr.io/huggingface/text-generation-inference:2.1.1 --model-id $model
``` ```
And then you can make requests like And then you can make requests like
@ -93,7 +93,7 @@ curl 127.0.0.1:8080/generate_stream \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.0-rocm --model-id $model` instead of the command above. **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
``` ```

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "2.0.1" "version": "2.1.2-dev0"
}, },
"paths": { "paths": {
"/": { "/": {
@ -19,7 +19,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"operationId": "compat_generate", "operationId": "compat_generate",
"requestBody": { "requestBody": {
"content": { "content": {
@ -108,7 +107,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens", "summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "generate", "operationId": "generate",
"requestBody": { "requestBody": {
"content": { "content": {
@ -192,7 +190,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate a stream of token using Server-Sent Events", "summary": "Generate a stream of token using Server-Sent Events",
"description": "Generate a stream of token using Server-Sent Events",
"operationId": "generate_stream", "operationId": "generate_stream",
"requestBody": { "requestBody": {
"content": { "content": {
@ -276,7 +273,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Health check method", "summary": "Health check method",
"description": "Health check method",
"operationId": "health", "operationId": "health",
"responses": { "responses": {
"200": { "200": {
@ -305,7 +301,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Text Generation Inference endpoint info", "summary": "Text Generation Inference endpoint info",
"description": "Text Generation Inference endpoint info",
"operationId": "get_model_info", "operationId": "get_model_info",
"responses": { "responses": {
"200": { "200": {
@ -327,7 +322,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Prometheus metrics scrape endpoint", "summary": "Prometheus metrics scrape endpoint",
"description": "Prometheus metrics scrape endpoint",
"operationId": "metrics", "operationId": "metrics",
"responses": { "responses": {
"200": { "200": {
@ -349,7 +343,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Tokenize inputs", "summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize", "operationId": "tokenize",
"requestBody": { "requestBody": {
"content": { "content": {
@ -394,7 +387,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens", "summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "chat_completions", "operationId": "chat_completions",
"requestBody": { "requestBody": {
"content": { "content": {
@ -483,7 +475,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens", "summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "completions", "operationId": "completions",
"requestBody": { "requestBody": {
"content": { "content": {
@ -626,7 +617,6 @@
"type": "object", "type": "object",
"required": [ "required": [
"id", "id",
"object",
"created", "created",
"model", "model",
"system_fingerprint", "system_fingerprint",
@ -653,9 +643,6 @@
"type": "string", "type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"object": {
"type": "string"
},
"system_fingerprint": { "system_fingerprint": {
"type": "string" "type": "string"
}, },
@ -697,7 +684,6 @@
"type": "object", "type": "object",
"required": [ "required": [
"id", "id",
"object",
"created", "created",
"model", "model",
"system_fingerprint", "system_fingerprint",
@ -723,9 +709,6 @@
"type": "string", "type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"object": {
"type": "string"
},
"system_fingerprint": { "system_fingerprint": {
"type": "string" "type": "string"
} }
@ -756,34 +739,19 @@
"nullable": true "nullable": true
}, },
"message": { "message": {
"$ref": "#/components/schemas/Message" "$ref": "#/components/schemas/OutputMessage"
} }
} }
}, },
"ChatCompletionDelta": { "ChatCompletionDelta": {
"type": "object", "oneOf": [
"required": [
"role"
],
"properties": {
"content": {
"type": "string",
"example": "What is Deep Learning?",
"nullable": true
},
"role": {
"type": "string",
"example": "user"
},
"tool_calls": {
"allOf": [
{ {
"$ref": "#/components/schemas/DeltaToolCall" "$ref": "#/components/schemas/TextMessage"
} },
], {
"nullable": true "$ref": "#/components/schemas/ToolCallDelta"
}
} }
]
}, },
"ChatCompletionLogprob": { "ChatCompletionLogprob": {
"type": "object", "type": "object",
@ -903,6 +871,15 @@
"example": 0.1, "example": 0.1,
"nullable": true "nullable": true
}, },
"response_format": {
"allOf": [
{
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"nullable": true
},
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
@ -969,6 +946,38 @@
} }
} }
}, },
"Chunk": {
"type": "object",
"required": [
"id",
"created",
"choices",
"model",
"system_fingerprint"
],
"properties": {
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/CompletionComplete"
}
},
"created": {
"type": "integer",
"format": "int64",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
}
},
"CompatGenerateRequest": { "CompatGenerateRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@ -988,6 +997,55 @@
} }
} }
}, },
"Completion": {
"oneOf": [
{
"allOf": [
{
"$ref": "#/components/schemas/Chunk"
},
{
"type": "object",
"required": [
"object"
],
"properties": {
"object": {
"type": "string",
"enum": [
"text_completion"
]
}
}
}
]
},
{
"allOf": [
{
"$ref": "#/components/schemas/CompletionFinal"
},
{
"type": "object",
"required": [
"object"
],
"properties": {
"object": {
"type": "string",
"enum": [
"text_completion"
]
}
}
}
]
}
],
"discriminator": {
"propertyName": "object"
}
},
"CompletionComplete": { "CompletionComplete": {
"type": "object", "type": "object",
"required": [ "required": [
@ -1017,15 +1075,15 @@
} }
} }
}, },
"CompletionCompleteChunk": { "CompletionFinal": {
"type": "object", "type": "object",
"required": [ "required": [
"id", "id",
"object",
"created", "created",
"choices",
"model", "model",
"system_fingerprint" "system_fingerprint",
"choices",
"usage"
], ],
"properties": { "properties": {
"choices": { "choices": {
@ -1037,19 +1095,21 @@
"created": { "created": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": "1706270835",
"minimum": 0 "minimum": 0
}, },
"id": { "id": {
"type": "string" "type": "string"
}, },
"model": { "model": {
"type": "string" "type": "string",
}, "example": "mistralai/Mistral-7B-Instruct-v0.2"
"object": {
"type": "string"
}, },
"system_fingerprint": { "system_fingerprint": {
"type": "string" "type": "string"
},
"usage": {
"$ref": "#/components/schemas/Usage"
} }
} }
}, },
@ -1081,12 +1141,7 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"prompt": { "prompt": {
"type": "array", "$ref": "#/components/schemas/Prompt"
"items": {
"type": "string"
},
"description": "The prompt to generate completions for.",
"example": "What is Deep Learning?"
}, },
"repetition_penalty": { "repetition_penalty": {
"type": "number", "type": "number",
@ -1100,6 +1155,15 @@
"nullable": true, "nullable": true,
"minimum": 0 "minimum": 0
}, },
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
},
"stream": { "stream": {
"type": "boolean" "type": "boolean"
}, },
@ -1121,15 +1185,6 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95, "example": 0.95,
"nullable": true "nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
} }
} }
}, },
@ -1272,8 +1327,16 @@
"GenerateParameters": { "GenerateParameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"adapter_id": {
"type": "string",
"description": "Lora adapter id",
"default": "null",
"example": "null",
"nullable": true
},
"best_of": { "best_of": {
"type": "integer", "type": "integer",
"description": "Generate best_of sequences and return the one if the highest token logprobs.",
"default": "null", "default": "null",
"example": 1, "example": 1,
"nullable": true, "nullable": true,
@ -1282,20 +1345,24 @@
}, },
"decoder_input_details": { "decoder_input_details": {
"type": "boolean", "type": "boolean",
"description": "Whether to return decoder input token logprobs and ids.",
"default": "false" "default": "false"
}, },
"details": { "details": {
"type": "boolean", "type": "boolean",
"description": "Whether to return generation details.",
"default": "true" "default": "true"
}, },
"do_sample": { "do_sample": {
"type": "boolean", "type": "boolean",
"description": "Activate logits sampling.",
"default": "false", "default": "false",
"example": true "example": true
}, },
"frequency_penalty": { "frequency_penalty": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
"default": "null", "default": "null",
"example": 0.1, "example": 0.1,
"nullable": true, "nullable": true,
@ -1313,6 +1380,7 @@
"max_new_tokens": { "max_new_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "Maximum number of tokens to generate.",
"default": "100", "default": "100",
"example": "20", "example": "20",
"nullable": true, "nullable": true,
@ -1321,6 +1389,7 @@
"repetition_penalty": { "repetition_penalty": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
"default": "null", "default": "null",
"example": 1.03, "example": 1.03,
"nullable": true, "nullable": true,
@ -1328,6 +1397,7 @@
}, },
"return_full_text": { "return_full_text": {
"type": "boolean", "type": "boolean",
"description": "Whether to prepend the prompt to the generated text",
"default": "null", "default": "null",
"example": false, "example": false,
"nullable": true "nullable": true
@ -1335,6 +1405,7 @@
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"description": "Random sampling seed.",
"default": "null", "default": "null",
"example": "null", "example": "null",
"nullable": true, "nullable": true,
@ -1346,6 +1417,7 @@
"items": { "items": {
"type": "string" "type": "string"
}, },
"description": "Stop generating tokens if a member of `stop` is generated.",
"example": [ "example": [
"photographer" "photographer"
], ],
@ -1354,6 +1426,7 @@
"temperature": { "temperature": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "The value used to module the logits distribution.",
"default": "null", "default": "null",
"example": 0.5, "example": 0.5,
"nullable": true, "nullable": true,
@ -1362,6 +1435,7 @@
"top_k": { "top_k": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
"default": "null", "default": "null",
"example": 10, "example": 10,
"nullable": true, "nullable": true,
@ -1370,6 +1444,7 @@
"top_n_tokens": { "top_n_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
"default": "null", "default": "null",
"example": 5, "example": 5,
"nullable": true, "nullable": true,
@ -1379,6 +1454,7 @@
"top_p": { "top_p": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "Top-p value for nucleus sampling.",
"default": "null", "default": "null",
"example": 0.95, "example": 0.95,
"nullable": true, "nullable": true,
@ -1387,6 +1463,7 @@
}, },
"truncate": { "truncate": {
"type": "integer", "type": "integer",
"description": "Truncate inputs tokens to the given size.",
"default": "null", "default": "null",
"example": "null", "example": "null",
"nullable": true, "nullable": true,
@ -1395,6 +1472,7 @@
"typical_p": { "typical_p": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
"default": "null", "default": "null",
"example": 0.95, "example": 0.95,
"nullable": true, "nullable": true,
@ -1403,6 +1481,7 @@
}, },
"watermark": { "watermark": {
"type": "boolean", "type": "boolean",
"description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
"default": "false", "default": "false",
"example": true "example": true
} }
@ -1495,13 +1574,14 @@
"max_concurrent_requests", "max_concurrent_requests",
"max_best_of", "max_best_of",
"max_stop_sequences", "max_stop_sequences",
"max_input_length", "max_input_tokens",
"max_total_tokens", "max_total_tokens",
"waiting_served_ratio", "waiting_served_ratio",
"max_batch_total_tokens", "max_batch_total_tokens",
"max_waiting_tokens", "max_waiting_tokens",
"validation_workers", "validation_workers",
"max_client_batch_size", "max_client_batch_size",
"router",
"version" "version"
], ],
"properties": { "properties": {
@ -1538,7 +1618,7 @@
"example": "128", "example": "128",
"minimum": 0 "minimum": 0
}, },
"max_input_length": { "max_input_tokens": {
"type": "integer", "type": "integer",
"example": "1024", "example": "1024",
"minimum": 0 "minimum": 0
@ -1581,6 +1661,11 @@
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2", "example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
"nullable": true "nullable": true
}, },
"router": {
"type": "string",
"description": "Router Info",
"example": "text-generation-router"
},
"sha": { "sha": {
"type": "string", "type": "string",
"example": "null", "example": "null",
@ -1593,7 +1678,6 @@
}, },
"version": { "version": {
"type": "string", "type": "string",
"description": "Router Info",
"example": "0.5.0" "example": "0.5.0"
}, },
"waiting_served_ratio": { "waiting_served_ratio": {
@ -1606,13 +1690,12 @@
"Message": { "Message": {
"type": "object", "type": "object",
"required": [ "required": [
"role" "role",
"content"
], ],
"properties": { "properties": {
"content": { "content": {
"type": "string", "$ref": "#/components/schemas/MessageContent"
"example": "My name is David and I",
"nullable": true
}, },
"name": { "name": {
"type": "string", "type": "string",
@ -1622,13 +1705,6 @@
"role": { "role": {
"type": "string", "type": "string",
"example": "user" "example": "user"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolCall"
},
"nullable": true
} }
} }
}, },
@ -1658,6 +1734,12 @@
} }
} }
}, },
"Prompt": {
"type": "array",
"items": {
"type": "string"
}
},
"SimpleToken": { "SimpleToken": {
"type": "object", "type": "object",
"required": [ "required": [
@ -1817,9 +1899,7 @@
"$ref": "#/components/schemas/FunctionDefinition" "$ref": "#/components/schemas/FunctionDefinition"
}, },
"id": { "id": {
"type": "integer", "type": "string"
"format": "int32",
"minimum": 0
}, },
"type": { "type": {
"type": "string" "type": "string"
@ -1830,20 +1910,22 @@
"oneOf": [ "oneOf": [
{ {
"type": "object", "type": "object",
"required": [ "default": null,
"FunctionName" "nullable": true
],
"properties": {
"FunctionName": {
"type": "string"
}
}
}, },
{ {
"type": "string", "type": "string"
"enum": [ },
"OneOf" {
] "type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
} }
] ]
}, },

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \ --device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \ --ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \ ghcr.io/huggingface/text-generation-inference:2.1.1-rocm \
--model-id $model --model-id $model
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0 \ ghcr.io/huggingface/text-generation-inference:2.1.1 \
--model-id $model --model-id $model
``` ```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0 \ ghcr.io/huggingface/text-generation-inference:2.1.1 \
--model-id $model --model-id $model
``` ```
@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more. To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash ```bash
docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help docker run ghcr.io/huggingface/text-generation-inference:2.1.1 --help
``` ```
</Tip> </Tip>

View File

@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b) - [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/google/gemma2-9b) - [Gemma2](https://huggingface.co/google/gemma2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)

View File

@ -5,85 +5,80 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.7890625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.625, "logprob": -11.34375,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3359375, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8779297, "logprob": -1.6230469,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2744141, "logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6933594, "logprob": -0.076660156,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4648438, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15600586, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.8027344, "logprob": -0.10821533,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.23022461,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0069885254,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.02218628,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
} }

View File

@ -5,85 +5,80 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.6015625, "logprob": -11.34375,
"text": " request" "text": " request"
} }
], ],
"seed": 0, "seed": 0,
"tokens": [ "tokens": [
{ {
"id": 29899, "id": 13,
"logprob": -1.5625, "logprob": -2.2539062,
"special": false, "special": false,
"text": "-" "text": "."
}, },
{ {
"id": 1454, "id": 578,
"logprob": -0.20410156, "logprob": -0.15563965,
"special": false, "special": false,
"text": "for" "text": " The"
}, },
{ {
"id": 29899, "id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "-" "text": " has"
}, },
{ {
"id": 9342, "id": 539,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "comment" "text": " not"
}, },
{ {
"id": 29901, "id": 3686,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": ":" "text": " yet"
}, },
{ {
"id": 396, "id": 3288,
"logprob": -0.27685547,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.4970703,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.80615234,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "1" "text": " sent"
}, },
{ {
"id": 29955, "id": 904,
"logprob": -1.0751953, "logprob": 0.0,
"special": false, "special": false,
"text": "7" "text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request-for-comment: #2017" "generated_text": "Test request. The server has not yet sent any data.\n\n"
} }

View File

@ -6,87 +6,82 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.828125,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.609375, "logprob": -11.34375,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3300781, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8740234, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2646484, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.7158203, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4667969, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15344238, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.81591797, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22973633,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021957397,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
}, },
{ {
"details": { "details": {
@ -95,87 +90,82 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.59375, "logprob": -11.34375,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3378906, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8779297, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2636719, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6992188, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4589844, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15344238, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.79052734, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22937012,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007041931,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022140503,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
}, },
{ {
"details": { "details": {
@ -184,87 +174,82 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.609375, "logprob": -11.34375,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3261719, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.8730469, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2587891, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6894531, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.46875, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.1541748, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.80322266, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22912598,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070495605,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021606445,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
}, },
{ {
"details": { "details": {
@ -273,86 +258,81 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 1, "id": 2323,
"logprob": null, "logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 2009, "id": 1715,
"logprob": -9.6015625, "logprob": -11.34375,
"text": " request" "text": " request"
} }
], ],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 13, "id": 198,
"logprob": -2.3320312, "logprob": -2.5742188,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 3057, "id": 262,
"logprob": -1.875, "logprob": -1.6220703,
"special": false, "special": false,
"text": "Test" "text": " "
}, },
{ {
"id": 2009, "id": 3270,
"logprob": -1.2646484, "logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 13, "id": 13204,
"logprob": -1.6884766, "logprob": -0.07672119,
"special": false, "special": false,
"text": "\n" "text": ".method"
}, },
{ {
"id": 3057, "id": 624,
"logprob": -1.4589844, "logprob": -0.021987915,
"special": false, "special": false,
"text": "Test" "text": " =="
}, },
{ {
"id": 2009, "id": 364,
"logprob": -0.15185547, "logprob": -0.39208984,
"special": false, "special": false,
"text": " request" "text": " '"
}, },
{ {
"id": 13, "id": 3019,
"logprob": -0.79833984, "logprob": -0.10638428,
"special": false, "special": false,
"text": "\n" "text": "POST"
},
{
"id": 3057,
"logprob": -0.22827148,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006996155,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021560669,
"special": false,
"text": "\n"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nTest request\nTest request\nTest request\n" "generated_text": "\n \"\"\"\n if request.method == 'POST"
} }
] ]

View File

@ -1,130 +1,124 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 20, "generated_tokens": 19,
"prefill": [], "prefill": [],
"seed": null, "seed": null,
"tokens": [ "tokens": [
{ {
"id": 415, "id": 415,
"logprob": -0.039886475, "logprob": -0.03665161,
"special": false, "special": false,
"text": " The" "text": " The"
}, },
{ {
"id": 12072, "id": 12072,
"logprob": -0.1430664, "logprob": -0.13549805,
"special": false, "special": false,
"text": " cow" "text": " cow"
}, },
{ {
"id": 349, "id": 349,
"logprob": -0.056488037, "logprob": -0.05819702,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 6328, "id": 6328,
"logprob": -0.6855469, "logprob": -0.6826172,
"special": false, "special": false,
"text": " standing" "text": " standing"
}, },
{ {
"id": 356, "id": 356,
"logprob": -0.1685791, "logprob": -0.1607666,
"special": false, "special": false,
"text": " on" "text": " on"
}, },
{ {
"id": 272, "id": 272,
"logprob": -0.50097656, "logprob": -0.5073242,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 10305, "id": 10305,
"logprob": -0.017303467, "logprob": -0.016418457,
"special": false, "special": false,
"text": " beach" "text": " beach"
}, },
{ {
"id": 304, "id": 304,
"logprob": -1.3564453, "logprob": -1.3916016,
"special": false, "special": false,
"text": " and" "text": " and"
}, },
{ {
"id": 272, "id": 272,
"logprob": -0.017868042, "logprob": -0.020217896,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 13088, "id": 13088,
"logprob": -0.0027103424, "logprob": -0.0028133392,
"special": false, "special": false,
"text": " chicken" "text": " chicken"
}, },
{ {
"id": 349, "id": 349,
"logprob": -0.003156662, "logprob": -0.003145218,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 6398, "id": 6398,
"logprob": -0.37304688, "logprob": -0.37060547,
"special": false, "special": false,
"text": " sitting" "text": " sitting"
}, },
{ {
"id": 356, "id": 356,
"logprob": -0.034576416, "logprob": -0.034851074,
"special": false, "special": false,
"text": " on" "text": " on"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.29418945, "logprob": -0.2878418,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 17972, "id": 17972,
"logprob": -0.042877197, "logprob": -0.046051025,
"special": false, "special": false,
"text": " pile" "text": " pile"
}, },
{ {
"id": 302, "id": 302,
"logprob": -0.00028443336, "logprob": -0.00028848648,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 2445, "id": 2445,
"logprob": -0.023223877, "logprob": -0.025772095,
"special": false, "special": false,
"text": " money" "text": " money"
}, },
{ {
"id": 28723, "id": 28723,
"logprob": -0.018157959, "logprob": -0.018127441,
"special": false, "special": false,
"text": "." "text": "."
}, },
{ {
"id": 32002, "id": 32002,
"logprob": -0.00018393993, "logprob": -0.00019824505,
"special": true, "special": true,
"text": "<end_of_utterance>" "text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
} }
], ],
"top_tokens": null "top_tokens": null

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.8359375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6171875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3417969,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8730469,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2626953,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7060547,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4482422,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15246582,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.796875,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22766113,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021759033,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.7890625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 29899,
"logprob": -1.4980469,
"special": false,
"text": "-"
},
{
"id": 1454,
"logprob": -0.19433594,
"special": false,
"text": "for"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 9342,
"logprob": 0.0,
"special": false,
"text": "comment"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 396,
"logprob": -0.27392578,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.49389648,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.81103516,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 29955,
"logprob": -1.0800781,
"special": false,
"text": "7"
}
],
"top_tokens": null
},
"generated_text": "Test request-for-comment: #2017"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.8828125,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.5859375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3359375,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8623047,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2451172,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6923828,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4492188,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15197754,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.8022461,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22583008,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007095337,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021652222,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.796875,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3476562,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8789062,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2734375,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.703125,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4677734,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15454102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.7973633,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.23278809,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006980896,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022033691,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.9296875,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.5703125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3203125,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8486328,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2480469,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7060547,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4511719,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.1529541,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.81396484,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22180176,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007133484,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021835327,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6171875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3261719,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8691406,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2597656,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7070312,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4550781,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.1538086,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.79345703,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22924805,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070266724,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021942139,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
}
]

View File

@ -5,7 +5,9 @@ from testing_utils import is_flaky_async, SYSTEM, require_backend_async
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher): def flash_llama_gptq_handle(launcher):
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle: with launcher(
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq"
) as handle:
yield handle yield handle

View File

@ -62,7 +62,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
response.generated_text response.generated_text
== " The cow is standing on the beach and the chicken is sitting on a pile of money." == " The cow is standing on the beach and the chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}" ), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 20 assert response.details.generated_tokens == 19
assert response == response_snapshot assert response == response_snapshot

View File

@ -433,8 +433,17 @@ pub struct CompletionRequest {
pub stop: Option<Vec<String>>, pub stop: Option<Vec<String>>,
} }
#[derive(Clone, Serialize, ToSchema)]
#[serde(tag = "object")]
enum Completion {
#[serde(rename = "text_completion")]
Chunk(Chunk),
#[serde(rename = "text_completion")]
Final(CompletionFinal),
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion { pub(crate) struct CompletionFinal {
pub id: String, pub id: String,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
@ -453,6 +462,15 @@ pub(crate) struct CompletionComplete {
pub finish_reason: String, pub finish_reason: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct Chunk {
pub id: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion { pub(crate) struct ChatCompletion {
pub id: String, pub id: String,
@ -614,15 +632,6 @@ impl ChatCompletion {
} }
} }
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Serialize, ToSchema)] #[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,

View File

@ -1,5 +1,6 @@
use axum::http::HeaderValue; use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType}; use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
struct Args { struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
@ -85,10 +89,15 @@ struct Args {
max_client_batch_size: usize, max_client_batch_size: usize,
} }
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), RouterError> { async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse(); let args = Args::parse();
// Pattern match configuration // Pattern match configuration
let Args { let Args {
max_concurrent_requests, max_concurrent_requests,
@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
command,
} = args; } = args;
// Launch Tokio runtime let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output); init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args // Validate args
if max_input_tokens >= max_total_tokens { if max_input_tokens >= max_total_tokens {
@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
print_schema_command,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -19,8 +19,8 @@ use crate::{
use crate::{ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse, VertexResponse,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
@ -705,7 +705,7 @@ async fn completions(
.as_secs(); .as_secs();
event event
.json_data(CompletionCompleteChunk { .json_data(Completion::Chunk(Chunk {
id: "".to_string(), id: "".to_string(),
created: current_time, created: current_time,
@ -718,7 +718,7 @@ async fn completions(
model: model_id.clone(), model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(), system_fingerprint: system_fingerprint.clone(),
}) }))
.unwrap_or_else(|_e| Event::default()) .unwrap_or_else(|_e| Event::default())
}; };
@ -931,7 +931,7 @@ async fn completions(
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?; .map_err(|(status, Json(err))| (status, Json(err)))?;
let response = Completion { let response = Completion::Final(CompletionFinal {
id: "".to_string(), id: "".to_string(),
created: current_time, created: current_time,
model: info.model_id.clone(), model: info.model_id.clone(),
@ -946,7 +946,7 @@ async fn completions(
completion_tokens, completion_tokens,
total_tokens, total_tokens,
}, },
}; });
// headers similar to `generate` but aggregated // headers similar to `generate` but aggregated
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
print_schema_command: bool,
) -> Result<(), WebServerError> { ) -> Result<(), WebServerError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -1463,7 +1464,10 @@ pub async fn run(
ChatCompletion, ChatCompletion,
CompletionRequest, CompletionRequest,
CompletionComplete, CompletionComplete,
CompletionCompleteChunk, Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters, GenerateParameters,
PrefillToken, PrefillToken,
Token, Token,
@ -1500,6 +1504,12 @@ pub async fn run(
struct ApiDoc; struct ApiDoc;
// Create state // Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
}
// Open connection, get model info and warmup // Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( let (scheduler, health_ext, shard_info, max_batch_total_tokens): (

View File

@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -16,7 +19,10 @@ def default_bloom():
revision = "main" revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors") filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision) download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id) return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_causal_lm(): def default_causal_lm():
return CausalLM("gpt2") return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -1,13 +1,12 @@
import pytest import pytest
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_santacoder(): def default_santacoder():
return SantaCoder("bigcode/santacoder") return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture @pytest.fixture

View File

@ -20,7 +20,7 @@ def mt0_small_tokenizer():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_seq2seq_lm(): def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small") return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture @pytest.fixture

View File

@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module):
beta_fast=32, beta_fast=32,
beta_slow=1, beta_slow=1,
) )
elif rope_scaling["type"] == "su": elif rope_scaling["type"] in ["su", "longrope"]:
short_factor = torch.tensor( short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device rope_scaling["short_factor"], dtype=torch.float32, device=device
) )

View File

@ -11,17 +11,27 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.opt import OPTSharded from text_generation_server.models.custom_modeling.neox_modeling import (
from text_generation_server.models.galactica import GalacticaSharded GPTNeoxForCausalLM,
from text_generation_server.models.santacoder import SantaCoder )
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.custom_modeling.phi_modeling import (
from text_generation_server.models.gpt_neox import GPTNeoxSharded PhiConfig,
from text_generation_server.models.phi import Phi PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -41,9 +51,6 @@ __all__ = [
"CausalLM", "CausalLM",
"GalacticaSharded", "GalacticaSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder",
"OPTSharded",
"T5Sharded",
"get_model", "get_model",
] ]
@ -53,38 +60,65 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
from text_generation_server.models.flash_neox import FlashNeoXSharded FlashLlamaForCausalLM,
from text_generation_server.models.flash_llama import (
FlashLlama,
) )
from text_generation_server.models.flash_qwen2 import ( from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashQwen2, FlashCohereForCausalLM,
) )
from text_generation_server.models.flash_cohere import ( from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashCohere, FlashGemmaForCausalLM,
) )
from text_generation_server.models.flash_gemma import ( from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma, FlashGemma2ForCausalLM,
) )
from text_generation_server.models.flash_gemma2 import ( from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashGemma2, FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
) )
from text_generation_server.models.pali_gemma import ( from text_generation_server.models.pali_gemma import (
PaliGemma, PaliGemmaBatch,
) )
from text_generation_server.models.flash_santacoder import ( from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
FlashSantacoderSharded, PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.custom_modeling.llava_next import (
from text_generation_server.models.idefics2 import Idefics2 LlavaNextForConditionalGeneration,
from text_generation_server.models.flash_mistral import FlashMistral )
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 FlashSantacoderForCausalLM,
from text_generation_server.models.flash_dbrx import FlashDbrx )
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e: except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}") logger.warning(f"Could not import Flash Attention enabled models: {e}")
@ -93,21 +127,7 @@ except ImportError as e:
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashCausalLM) __all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(IDEFICSSharded) __all__.append(IDEFICSSharded)
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashDbrx)
__all__.append(FlashPhi)
__all__.append(FlashQwen2)
__all__.append(FlashStarcoder2)
__all__.append(FlashGemma)
__all__.append(FlashGemma2)
__all__.append(FlashCohere)
MAMBA_AVAILABLE = True MAMBA_AVAILABLE = True
MAMBA_IMPORT_ERROR = None MAMBA_IMPORT_ERROR = None
@ -150,6 +170,11 @@ class ModelType(enum.Enum):
"name": "Gemma", "name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b", "url": "https://huggingface.co/google/gemma-7b",
} }
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = { GEMMA2 = {
"type": "gemma2", "type": "gemma2",
"name": "Gemma2", "name": "Gemma2",
@ -452,13 +477,16 @@ def get_model(
) )
if model_id.startswith("facebook/galactica"): if model_id.startswith("facebook/galactica"):
return GalacticaSharded( return CausalLM(
model_id, model_id=model_id,
revision, # Yes galactica is just an OPT model.
model_class=OPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=GalacticaCausalLMBatch,
) )
if ( if (
@ -467,22 +495,26 @@ def get_model(
and model_id.startswith("bigcode/") and model_id.startswith("bigcode/")
): ):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashSantacoderSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
) )
else: else:
return SantaCoder( return CausalLM.fallback(
model_id, model_id=model_id,
revision, revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -490,38 +522,44 @@ def get_model(
) )
if model_type == BLOOM: if model_type == BLOOM:
return BLOOMSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=BloomForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=BloomCausalLMBatch,
) )
elif model_type == MPT: elif model_type == MPT:
return MPTSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=MPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
) )
elif model_type == GPT2: elif model_type == GPT2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
try: try:
return FlashGPT2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
except RuntimeError as e: except RuntimeError as e:
# Lots of legacy models with various weight names. # Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}") logger.warning(f"Couldn't load flash gpt2 variant: {e}")
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -532,7 +570,7 @@ def get_model(
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -542,25 +580,28 @@ def get_model(
) )
elif model_type == GPT_NEOX: elif model_type == GPT_NEOX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashNeoXSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
return GPTNeoxSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=GPTNeoxForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -571,16 +612,18 @@ def get_model(
elif model_type == PHI: elif model_type == PHI:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashPhi( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -595,9 +638,11 @@ def get_model(
"Legacy phi-msft is not supported with Flash Attention" "Legacy phi-msft is not supported with Flash Attention"
) )
else: else:
return Phi( return CausalLM(
model_id, model_id=model_id,
revision, model_class=PhiForCausalLM,
config_class=PhiConfig,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -606,9 +651,10 @@ def get_model(
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashLlama( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -618,7 +664,7 @@ def get_model(
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -628,18 +674,22 @@ def get_model(
) )
if model_type == GEMMA: if model_type == GEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashGemma( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -649,18 +699,22 @@ def get_model(
) )
elif model_type == GEMMA2: elif model_type == GEMMA2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashGemma2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -671,18 +725,20 @@ def get_model(
if model_type == COHERE: if model_type == COHERE:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashCohere( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -693,18 +749,23 @@ def get_model(
if model_type == DBRX: if model_type == DBRX:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashDbrx( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -718,27 +779,37 @@ def get_model(
if FLASH_ATTENTION: if FLASH_ATTENTION:
if config_dict.get("alibi", False): if config_dict.get("alibi", False):
raise NotImplementedError("sharded is not supported for this model") raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
) )
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
else: else:
if FLASH_ATTENTION and not config_dict.get("alibi", False): if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRWSharded( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
) )
else: else:
return RW( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -749,18 +820,20 @@ def get_model(
if model_type == MISTRAL: if model_type == MISTRAL:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMistral( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -771,18 +844,20 @@ def get_model(
if model_type == MIXTRAL: if model_type == MIXTRAL:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashMixtral( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -793,19 +868,22 @@ def get_model(
if model_type == STARCODER2: if model_type == STARCODER2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashStarcoder2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
) )
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -816,17 +894,20 @@ def get_model(
if model_type == QWEN2: if model_type == QWEN2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashQwen2( return FlashCausalLM(
model_id, model_id=model_id,
revision, model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else: else:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -836,9 +917,10 @@ def get_model(
) )
if model_type == OPT: if model_type == OPT:
return OPTSharded( return CausalLM(
model_id, model_id=model_id,
revision, model_class=OPTForCausalLM,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -846,13 +928,20 @@ def get_model(
) )
if model_type == T5: if model_type == T5:
return T5Sharded( return Seq2SeqLM(
model_id, model_id=model_id,
revision, model_class=T5ForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
) )
if model_type == IDEFICS: if model_type == IDEFICS:
if FLASH_ATTENTION: if FLASH_ATTENTION:
@ -868,34 +957,45 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS2: if model_type == IDEFICS2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return Idefics2( return VlmCausalLM(
model_id, model_id=model_id,
revision, model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma": if model_type == PALIGEMMA:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return PaliGemma( return VlmCausalLM(
model_id, model_id=model_id,
revision, model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == LLAVA_NEXT: if model_type == LLAVA_NEXT:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return LlavaNext( return VlmCausalLM(
model_id, model_class=LlavaNextForConditionalGeneration,
revision, model_id=model_id,
revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
@ -919,7 +1019,7 @@ def get_model(
elif quantize == "exl2": elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel") raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -928,7 +1028,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM( return Seq2SeqLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -940,7 +1040,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None) auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None: if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys(): if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM( return CausalLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -949,7 +1049,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if "AutoModelForSeq2SeqLM" in auto_map.keys(): if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM( return Seq2SeqLM.fallback(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,

View File

@ -4,22 +4,12 @@ import torch.distributed
from typing import Optional, Type from typing import Optional, Type
from transformers import ( from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOMSharded(CausalLM): class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
slow_but_exact=False,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = BloomForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch return BloomCausalLMBatch

View File

@ -1,13 +1,25 @@
import torch import torch
import time import time
import torch.distributed
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
)
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
@ -478,10 +490,88 @@ class CausalLMBatch(Batch):
return len(self.requests) return len(self.requests)
@dataclass
class CausalLMBatchKeysLast(Batch):
keys_head_dim_last: bool = False
class CausalLM(Model): class CausalLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
batch_class=CausalLMBatch,
):
self.batch_class = batch_class
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculator: Optional[str] = None, speculator: Optional[str] = None,
@ -537,7 +627,12 @@ class CausalLM(Model):
else: else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__( self = cls.__new__(
cls,
)
self.batch_class = CausalLMBatch
super().__init__(
self,
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -545,15 +640,11 @@ class CausalLM(Model):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
return self
@property @property
def batch_type(self) -> Type[CausalLMBatch]: def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch return self.batch_class
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@ -815,7 +815,7 @@ class BloomModel(BloomPreTrainedModel):
class BloomForCausalLM(BloomPreTrainedModel): class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.transformer = BloomModel(config, weights) self.transformer = BloomModel(config, weights)

View File

@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module): class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig): def __init__(self, prefix: str, config: CLIPTextConfig):
super().__init__() super().__init__()
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig): def __init__(self, prefix, config: CLIPTextConfig):
super().__init__(config) super().__init__(config)
self.text_model = CLIPTextTransformer(config) self.text_model = CLIPTextTransformer(prefix, config)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()

View File

@ -363,9 +363,9 @@ class CohereMLP(nn.Module):
class FlashCohereLayer(nn.Module): class FlashCohereLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention( self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module):
class FlashCohereModel(torch.nn.Module): class FlashCohereModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashCohereLayer( FlashCohereLayer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module):
] ]
) )
self.norm = FastLayerNorm.load_no_bias( self.norm = FastLayerNorm.load_no_bias(
prefix="model.norm", weights=weights, eps=config.layer_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module):
class FlashCohereForCausalLM(torch.nn.Module): class FlashCohereForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = FlashCohereModel(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashCohereModel(prefix, config, weights)
try: try:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
except RuntimeError: except RuntimeError:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens", prefix=f"{prefix}.embed_tokens",
weights=weights, weights=weights,
) )
self.logit_scale = config.logit_scale self.logit_scale = config.logit_scale

View File

@ -593,9 +593,9 @@ class DenseMoE(nn.Module):
class DbrxLayer(nn.Module): class DbrxLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"transformer.blocks.{layer_id}" prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm( self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
@ -637,16 +637,17 @@ class DbrxLayer(nn.Module):
class DbrxModel(torch.nn.Module): class DbrxModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.wte", weights=weights prefix=f"{prefix}.wte", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
DbrxLayer( DbrxLayer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module):
] ]
) )
self.norm = FastLayerNorm.load_no_bias( self.norm = FastLayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=1e-5 prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
) )
self.head_size = self.layers[0].attn.self_attn.head_size self.head_size = self.layers[0].attn.self_attn.head_size
@ -702,10 +703,15 @@ class DbrxModel(torch.nn.Module):
class FlashDbrxForCausalLM(torch.nn.Module): class FlashDbrxForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = DbrxModel(config, weights) if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = DbrxModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",

View File

@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
class Gemma2FastRMSNorm(FastRMSNorm): class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod @classmethod
def load(cls, prefix, weights, eps=1e-6): def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype dtype = weights.dtype
weights.dtype = torch.float32 weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1 weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights): def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
else: else:
@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module):
class FlashGemma2Layer(nn.Module): class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__() super().__init__()
self.self_attn = FlashGemma2Attention( self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module):
class FlashGemma2Model(torch.nn.Module): class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
class FlashGemma2ForCausalLM(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__() super().__init__()
embed_norm = config.hidden_size**0.5 embed_norm = config.hidden_size**0.5

View File

@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm): class GemmaFastRMSNorm(FastRMSNorm):
@classmethod @classmethod
def load(cls, prefix, weights, eps=1e-6): def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype dtype = weights.dtype
weights.dtype = torch.float32 weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1 weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights): def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
else: else:
@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module):
class GemmaMLP(nn.Module): class GemmaMLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -299,7 +299,7 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module): class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
self.self_attn = FlashGemmaAttention( self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module):
class FlashGemmaModel(torch.nn.Module): class FlashGemmaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool): def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__() super().__init__()
embed_norm = config.hidden_size**0.5 embed_norm = config.hidden_size**0.5

View File

@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module):
class GPT2MLP(nn.Module): class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
act = config.activation_function act = config.activation_function
self.act = ( self.act = (
@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
class FlashGPT2Layer(nn.Module): class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.self_attn = FlashGPT2Attention( self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights prefix=f"{prefix}.attn", config=config, weights=weights
@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module):
class FlashGPT2Model(torch.nn.Module): class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module):
class FlashGPT2ForCausalLM(torch.nn.Module): class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(

View File

@ -54,7 +54,7 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights, layer_id): def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite. # Only defined in granite.
bias = getattr(config, "attention_bias", False) bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads head_size = config.hidden_size // config.num_attention_heads
@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(

View File

@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
class MistralMLP(nn.Module): class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id): def __init__(self, prefix: str, config, weights, layer_id):
super().__init__() super().__init__()
self.hidden_act = config.hidden_act self.hidden_act = config.hidden_act
self.act = ( self.act = (
@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module): class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_id): def __init__(self, prefix: str, config, weights, layer_id):
super().__init__() super().__init__()
self.self_attn = MistralAttention( self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
class MistralModel(torch.nn.Module): class MistralModel(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, name=None): def __init__(self, prefix: str, config, weights, name=None):
if name is None: if name is None:
name = "model" name = "model"
super().__init__() super().__init__()

View File

@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights): def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
else: else:
@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights):
) )
def _load_experts(config, prefix, mat, weights): def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None: if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.") raise NotImplementedError("Mixtral does not support weight quantization yet.")
@ -475,7 +475,7 @@ class DenseMoE(nn.Module):
class MixtralLayer(nn.Module): class MixtralLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"{prefix}.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
@ -536,7 +536,7 @@ class MixtralLayer(nn.Module):
class MixtralModel(torch.nn.Module): class MixtralModel(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module):
class FlashMixtralForCausalLM(torch.nn.Module): class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = MixtralModel(prefix, config, weights) self.model = MixtralModel(prefix, config, weights)

View File

@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.embed_in = TensorParallelEmbedding( self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights prefix=f"{prefix}.embed_in", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
] ]
) )
self.final_layer_norm = FastLayerNorm.load( self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm", prefix=f"{prefix}.final_layer_norm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights

View File

@ -258,9 +258,9 @@ class PhiMLP(nn.Module):
class FlashPhiLayer(nn.Module): class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention( self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module):
class FlashPhiModel(torch.nn.Module): class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashPhiLayer( FlashPhiLayer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module):
class FlashPhiForCausalLM(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = FlashPhiModel(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashPhiModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",

View File

@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
class Qwen2Layer(nn.Module): class Qwen2Layer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention( self.self_attn = Qwen2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
class Qwen2Model(torch.nn.Module): class Qwen2Model(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Qwen2Layer( Qwen2Layer(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
] ]
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
class Qwen2ForCausalLM(torch.nn.Module): class Qwen2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = Qwen2Model(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Qwen2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",

View File

@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
config, config,
prefix, prefix: str,
weights, weights,
): ):
super().__init__() super().__init__()
@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
config, config,
prefix, prefix: str,
weights, weights,
): ):
super().__init__() super().__init__()
@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module): class FlashMLP(nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix: str, weights):
super().__init__() super().__init__()
self.act = torch.nn.functional.gelu self.act = torch.nn.functional.gelu
@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module):
def __init__( def __init__(
self, self,
layer_id, layer_id,
prefix: str,
config, config,
weights, weights,
): ):
@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module):
parallel_attn = config.parallel_attn parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn self.parallel_attn = parallel_attn
prefix = f"transformer.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load( self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", prefix=f"{prefix}.input_layernorm",
@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module):
class FlashRWLayerNorm(nn.Module): class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix: str, weights):
super().__init__() super().__init__()
self.num_ln = config.num_ln_in_parallel_attn self.num_ln = config.num_ln_in_parallel_attn
@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module):
class FlashRWLargeLayer(nn.Module): class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, prefix: str, config, weights):
super().__init__() super().__init__()
prefix = f"transformer.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
self.ln_layer = FlashRWLayerNorm(config, prefix, weights) self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.word_embeddings = TensorParallelEmbedding( self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights prefix=f"{prefix}.word_embeddings", weights=weights
) )
if config.new_decoder_architecture: if config.new_decoder_architecture:
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLargeLayer(layer_id, config, weights) FlashRWLargeLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
else: else:
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLayer(layer_id, config, weights) FlashRWLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.cache_size = self.h[0].self_attention.num_heads_kv self.cache_size = self.h[0].self_attention.num_heads_kv
self.ln_f = FastLayerNorm.load( self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f", prefix=f"{prefix}.ln_f",
weights=weights, weights=weights,
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.transformer = FlashRWModel(config, weights) if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.transformer = FlashRWModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

View File

@ -346,16 +346,16 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix: str, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"transformer.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
self.ln_1 = FastLayerNorm.load( self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
) )
self.ln_2 = FastLayerNorm.load( self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
) )
self.attn = FlashMQAttention( self.self_attn = FlashMQAttention(
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
config=config, config=config,
weights=weights, weights=weights,
@ -378,7 +378,7 @@ class Block(nn.Module):
max_s, max_s,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn( hidden_states = self.self_attn(
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
@ -396,25 +396,26 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module): class FlashSantacoderModel(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.config = config self.config = config
self.process_group = weights.process_group self.process_group = weights.process_group
self.wte = TensorParallelEmbedding( self.wte = TensorParallelEmbedding(
prefix="transformer.wte", prefix=f"{prefix}.wte",
weights=weights, weights=weights,
reduce=False, reduce=False,
) )
self.wpe = TensorParallelEmbedding( self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe", prefix=f"{prefix}.wpe",
weights=weights, weights=weights,
reduce=False, reduce=False,
) )
self.h = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Block( Block(
prefix,
layer_id, layer_id,
config, config,
weights, weights,
@ -426,8 +427,8 @@ class FlashSantacoderModel(nn.Module):
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
) )
self.head_size = self.h[0].attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.h[0].attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
def forward( def forward(
self, self,
@ -446,7 +447,7 @@ class FlashSantacoderModel(nn.Module):
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None residual = None
for i, layer in enumerate(self.h): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
residual, residual,
@ -464,11 +465,18 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module): class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.transformer = FlashSantacoderModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
config.transpose = config.architectures[0].startswith("GPT2")
self.model = FlashSantacoderModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix=f"{prefix}.wte", weights=weights
) )
def forward( def forward(
@ -485,7 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,

View File

@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module):
class Starcoder2Model(torch.nn.Module): class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights prefix=f"{prefix}.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module):
] ]
) )
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module):
class FlashStarcoder2ForCausalLM(torch.nn.Module): class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.model = Starcoder2Model(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Starcoder2Model(prefix, config, weights)
try: try:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
except RuntimeError: except RuntimeError:
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens", prefix=f"{prefix}.embed_tokens",
weights=weights, weights=weights,
) )

View File

@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.config = config self.config = config
config.text_config.quantize = config.quantize config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator config.text_config.speculator = config.speculator
self.language_model = load_text_model( self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model", prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config, config=config.text_config,
weights=weights, weights=weights,
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None, image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
): ):
inputs_embeds = self.language_model.embed_tokens(input_ids) inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0: if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
input_ids, inputs_embeds, image_features input_ids, inputs_embeds, image_features
) )
hidden_states = self.language_model.model( hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states) logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits return logits, speculative_logits

View File

@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
class MPTModel(MPTPreTrainedModel): class MPTModel(MPTPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
# config._validate_config() # config._validate_config()
super().__init__(config) super().__init__(config)
self.world_size = weights.process_group.size() self.world_size = weights.process_group.size()
@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
f"Requested norm type ({config.norm_type}) is not implemented within this repo." f"Requested norm type ({config.norm_type}) is not implemented within this repo."
) )
self.wte = TensorParallelEmbedding("transformer.wte", weights) self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
if not self.alibi: if not self.alibi:
self.wpe = TensorParallelEmbedding("transformer.wpe", weights) self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
for i in range(config.n_layers) for i in range(config.n_layers)
] ]
) )
@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
class MPTForCausalLM(MPTPreTrainedModel): class MPTForCausalLM(MPTPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
if not config.tie_word_embeddings: if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings") raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights) self.transformer = MPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix=f"{prefix}.wte", weights=weights
) )
self.logit_scale = None self.logit_scale = None
if config.logit_scale is not None: if config.logit_scale is not None:

View File

@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
class GPTNeoXLayer(nn.Module): class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, prefix: str, config, weights):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load( self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
self.post_attention_layernorm = nn.LayerNorm.load( self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
self.attention = GPTNeoXAttention( self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
) )
self.mlp = GPTNeoXMLP( self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
) )
def forward( def forward(
@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
class GPTNeoXModel(GPTNeoXPreTrainedModel): class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding( self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights prefix=f"{prefix}.embed_in", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
GPTNeoXLayer(layer_id, config, weights) GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.final_layer_norm = nn.LayerNorm.load( self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm", prefix=f"{prefix}.final_layer_norm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
) )
@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights
) )

View File

@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
This module learns positional embeddings up to a fixed maximum size. This module learns positional embeddings up to a fixed maximum size.
""" """
def __init__(self, weights): def __init__(self, prefix: str, weights):
super().__init__() super().__init__()
self.offset = 2 self.offset = 2
self.weight = nn.Parameter( self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight") weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
) )
def forward( def forward(
@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
class OPTDecoderLayer(nn.Module): class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights): def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}" prefix = f"{prefix}.decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
class OPTDecoder(OPTPreTrainedModel): class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights): def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config) super().__init__(config)
self.dropout = config.dropout self.dropout = config.dropout
self.layerdrop = config.layerdrop self.layerdrop = config.layerdrop
@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights prefix=f"{prefix}.decoder.embed_tokens", weights=weights
) )
self.embed_positions = OPTLearnedPositionalEmbedding(weights) self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load( self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", weights=weights, bias=False config,
prefix=f"{prefix}.decoder.project_out",
weights=weights,
bias=False,
) )
else: else:
self.project_out = None self.project_out = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load( self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", weights=weights, bias=False config,
prefix=f"{prefix}.decoder.project_in",
weights=weights,
bias=False,
) )
else: else:
self.project_in = None self.project_in = None
@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164 # see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm: if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load( self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
) )
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
OPTDecoderLayer(layer_id, config, weights) OPTDecoderLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
class OPTModel(OPTPreTrainedModel): class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights): def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config) super().__init__(config)
self.decoder = OPTDecoder(config, weights) self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing # Initialize weights and apply final processing
def forward( def forward(
@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
class OPTForCausalLM(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights): def __init__(self, prefix, config, weights):
super().__init__(config) super().__init__(config)
self.model = OPTModel(config, weights) if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
) )
def forward( def forward(

View File

@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
# PhiModel implements the embedding layer and the transformer blocks. # PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module): class PhiModel(nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.tp_rank = weights.process_group.rank() self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size() self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights prefix=f"{prefix}.embd.wte", weights=weights
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
PhiBlock(f"transformer.h.{layer_id}", config, weights) PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer) for layer_id in range(config.n_layer)
] ]
) )
@ -289,9 +289,15 @@ class PhiModel(nn.Module):
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module): class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
self.model = PhiModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights) self.lm_head = PhiCausalLMHead(config, weights)
def forward( def forward(

View File

@ -10,7 +10,12 @@ import numpy as np
from loguru import logger from loguru import logger
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@ -21,6 +26,12 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
from text_generation_server.models.types import ( from text_generation_server.models.types import (
Batch, Batch,
Tokens, Tokens,
@ -798,29 +809,120 @@ class FlashCausalLMBatch(Batch):
return len(self.requests) return len(self.requests)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashCausalLM(Model): class FlashCausalLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model: torch.nn.Module, model_class,
tokenizer: PreTrainedTokenizerBase, revision: Optional[str] = None,
num_layers: int, quantize: Optional[str] = None,
num_kv_heads: int, speculator: Optional[str] = None,
head_size: int, dtype: Optional[torch.dtype] = None,
dtype: torch.dtype, trust_remote_code: bool = False,
device: torch.device, lora_adapter_ids: Optional[list] = [],
rank: int = 0, tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
world_size: int = 1, config_class: PreTrainedTokenizerBase = AutoConfig,
sliding_window: Optional[int] = None, default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads=None,
skip_special_tokens: bool = True,
): ):
self.num_layers = num_layers self.process_group, rank, world_size = initialize_torch_distributed()
self.num_kv_heads = num_kv_heads if torch.cuda.is_available():
self.head_size = head_size device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError(f"{model_class} is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device, dtype, process_group=self.process_group, aliases=aliases
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
# Order is important here.
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None:
break
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
else num_kv_heads
)
assert self.num_kv_heads > 0
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {} self.cuda_graphs = {}
self.kv_cache = [] self.kv_cache = []
super(FlashCausalLM, self).__init__( super().__init__(
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -829,7 +931,7 @@ class FlashCausalLM(Model):
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
sliding_window=sliding_window, sliding_window=config.sliding_window,
) )
@property @property
@ -1577,3 +1679,72 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns) return generations, batch, (forward_ns, decode_ns)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesnt have these layers
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@ -1,75 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashCohere(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashCohere is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashCohereForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,100 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashDbrx(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashDBRX is only available on GPU")
try:
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
try:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
# FIXME: change back to model id once the tokenizer.json is merged
tokenizer = GPT2TokenizerFast.from_pretrained(
"Xenova/dbrx-instruct-tokenizer",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = DbrxConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashDbrxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PretrainedConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGPT2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,171 +0,0 @@
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@ -1,24 +1,7 @@
import torch import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [ ADAPTER_LAYERS = [
@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM): class FlashMistral(FlashCausalLM):
def __init__(
self,
model_cls,
model_id: str,
config_cls=AutoConfig,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashMistral is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_cls(prefix, config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.model.layers),
model.model.num_key_value_heads,
model.model.head_size,
)
@property @property
def supports_adapter_loading(self) -> bool: def supports_adapter_loading(self) -> bool:
return True return True
@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):
# This accounts for VLMs (e.g. LlavaNext, Idefics2) # This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model. # that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"): if hasattr(self.model, "text_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model _model = self.model.text_model
else: else:
_model = self.model _model = self.model
@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):
def is_row_parallel(self, layer_type: str) -> bool: def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMistral, self).__init__(
config_cls=MistralConfig,
model_cls=FlashMistralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -1,31 +0,0 @@
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashNeoXSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,111 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if speculator:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(speculator).exists() and Path(speculator).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
speculator, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
speculator, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(speculator) / "config.json")
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,93 +0,0 @@
import math
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashQwen2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = Qwen2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)

View File

@ -1,91 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
)
config.quantize = quantize
config.speculator = speculator
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashRWSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -1,99 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
import json
import os
from huggingface_hub import hf_hub_download
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashSantacoderSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=True,
)
config.quantize = quantize
config.speculator = speculator
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

View File

@ -1,84 +0,0 @@
import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashStarcoder2 is only available on GPU")
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = Starcoder2Config.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashStarcoder2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)

View File

@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
max_tokens=max_tokens, max_tokens=max_tokens,
) )
class GalacticaSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -1,89 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class GPTNeoxSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = GPTNeoxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -1,51 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@ -1,46 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class LlavaNext(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
super().__init__(
model_cls=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)

View File

@ -63,7 +63,7 @@ class Model(ABC):
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
LayerAdapterWeights LayerAdapterWeights
) )
self.target_to_layer = self.adapter_target_to_layer() self.target_to_layer = None
self.loaded_adapters = set() self.loaded_adapters = set()
self.static_adapter_id = adapter_id self.static_adapter_id = adapter_id
@ -206,6 +206,8 @@ class Model(ABC):
into model. Otherwise, the adapter weights are applied during the forward into model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters. pass and stored separately from the base model parameters.
""" """
if self.target_to_layer is None:
self.target_to_layer = self.adapter_target_to_layer()
if adapter_index in self.loaded_adapters: if adapter_index in self.loaded_adapters:
# Adapter already loaded # Adapter already loaded
return return

View File

@ -1,105 +0,0 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class MPTCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class MPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
config.quantize = quantize
model = MPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MPTCausalLMBatch

View File

@ -1,86 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class OPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
else: else:
image_inputs = None image_inputs = None
return batch_tokenized_inputs, image_inputs return batch_tokenized_inputs, image_inputs
class PaliGemma(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
super().__init__(
config_cls=AutoConfig,
model_cls=PaliGemmaForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self):
return PaliGemmaBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)

View File

@ -1,69 +0,0 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, _world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)

View File

@ -1,84 +0,0 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
class RW(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if speculator:
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
# Model Forward
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -1,77 +0,0 @@
import torch
import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
EOD,
FIM_PREFIX,
FIM_MIDDLE,
FIM_SUFFIX,
FIM_PAD,
],
"pad_token": EOD,
}
)
with device:
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

View File

@ -1,11 +1,22 @@
import torch import torch
import torch.distributed
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model from text_generation_server.models import Model
@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
config_class=AutoConfig,
tokenizer_class=AutoTokenizer,
aliases=None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases=aliases,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = model_class(config, weights)
torch.distributed.barrier(group=self.process_group)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculator: Optional[str] = None, speculator: Optional[str] = None,
@ -574,7 +659,11 @@ class Seq2SeqLM(Model):
) )
tokenizer.bos_token_id = model.config.decoder_start_token_id tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__( self = cls.__new__(
cls,
)
super().__init__(
self,
model_id=model_id, model_id=model_id,
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
return self
@property @property
def batch_type(self) -> Type[Seq2SeqLMBatch]: def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, self,
input_ids, input_ids,

View File

@ -1,115 +0,0 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class T5Sharded(Seq2SeqLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.speculator = speculator
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
)
model = T5ForConditionalGeneration(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
def forward(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
outputs, speculative_logits = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_last_hidden_state,
past_key_values=past_key_values,
use_cache=True,
)
return (
outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state,
outputs.past_key_values,
)

View File

@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_causal_lm import (
from text_generation_server.models.flash_mistral import ( FlashCausalLMBatch,
BaseFlashMistral, FlashCausalLM,
) )
from transformers import AutoProcessor
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
return batch return batch
class VlmCausalLM(BaseFlashMistral): class VlmCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
*,
processor_class=AutoProcessor,
processor_kwargs=None,
batch_class=VlmCausalLMBatch,
revision,
trust_remote_code: bool,
**kwargs,
):
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
**processor_kwargs,
)
self.batch_class = batch_class
super().__init__(model_id=model_id, **kwargs)
@property @property
def batch_type(self) -> Type[VlmCausalLMBatch]: def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch return self.batch_class
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def forward( def forward(
self, self,

View File

@ -1,6 +1,8 @@
import subprocess import subprocess
import argparse import argparse
import ast import ast
import json
import os
TEMPLATE = """ TEMPLATE = """
# Supported Models and Hardware # Supported Models and Hardware
@ -122,6 +124,53 @@ def check_supported_models(check: bool):
f.write(final_doc) f.write(final_doc)
def get_openapi_schema():
try:
output = subprocess.check_output(["text-generation-router", "print-schema"])
return json.loads(output)
except subprocess.CalledProcessError as e:
print(f"Error running text-generation-router print-schema: {e}")
raise SystemExit(1)
except json.JSONDecodeError:
print("Error: Invalid JSON received from text-generation-router print-schema")
raise SystemExit(1)
def check_openapi(check: bool):
new_openapi_data = get_openapi_schema()
filename = "docs/openapi.json"
tmp_filename = "openapi_tmp.json"
with open(tmp_filename, "w") as f:
json.dump(new_openapi_data, f, indent=2)
if check:
diff = subprocess.run(
[
"diff",
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"--ignore-trailing-space",
tmp_filename,
filename,
],
capture_output=True,
).stdout.decode()
os.remove(tmp_filename)
if diff:
print(diff)
raise Exception(
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
)
return True
else:
os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.")
return True
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true") parser.add_argument("--check", action="store_true")
@ -130,6 +179,7 @@ def main():
check_cli(args.check) check_cli(args.check)
check_supported_models(args.check) check_supported_models(args.check)
check_openapi(args.check)
if __name__ == "__main__": if __name__ == "__main__":