Merge branch 'huggingface:main' into tylertitsworth/numba-cache-fix

Commit fe920d5c76 by Tyler Titsworth, 2024-09-17 09:00:09 -07:00, committed by GitHub.
No known key found for this signature in database (GPG Key ID: B5690EEEBB952194).
109 changed files with 10344 additions and 2604 deletions.

View File

@ -32,10 +32,6 @@ jobs:
permissions:
contents: write
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v4

View File

@ -39,6 +39,9 @@ jobs:
matrix:
hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
uses: ./.github/workflows/build.yaml # calls the one above ^
permissions:
contents: write
packages: write
with:
hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206

.github/workflows/nix_tests.yaml (new file, 41 lines added)
View File

@ -0,0 +1,41 @@
name: "Nix Tests"
on:
pull_request:
paths:
- ".github/workflows/nix_tests.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix develop .#test --command echo "Ok"
- name: Pre-commit tests.
run: nix develop .#test --command pre-commit run --all-files
- name: Python tests.
run: nix develop .#test --command python -m pytest server/tests/
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Rust tests.
run: nix build .#checks.$(nix eval --impure --raw --expr 'builtins.currentSystem').rust -L

View File

@ -17,25 +17,21 @@ concurrency:
jobs:
run_tests:
runs-on: ubuntu-latest
env:
SCCACHE_GHA_ENABLED: "on"
RUSTC_WRAPPER: /usr/local/bin/sccache
SCCACHE: 0.3.3
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v4
id: python
with:
python-version: 3.9
python-version: 3.11
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
# Released on: 02 May, 2024
# https://releases.rs/docs/1.78.0/
toolchain: 1.79.0
toolchain: 1.80.0
override: true
components: rustfmt, clippy
- name: Install Protoc
@ -44,30 +40,9 @@ jobs:
run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install sccache
run: |
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
chmod +x /usr/local/bin/sccache
- name: configure sccache
uses: actions/github-script@v6
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
- name: cargo registry cache
uses: actions/cache@v3
with:
key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
restore-keys: |
cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
cargo-${{ runner.os }}-
path: |
~/.cargo/registry
~/.cargo/git
- name: Install
run: |
sudo apt install python3.11-dev -y
make install-cpu
- name: Run server tests
run: |
@ -82,6 +57,3 @@ jobs:
- name: Run Rust tests
run: |
cargo test
- name: sccache stats
run: |
/usr/local/bin/sccache --show-stats

.gitignore (3 lines added)
View File

@ -19,3 +19,6 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
data/
load_tests/*.json
server/fbgemmm
.direnv/
.venv/

View File

@ -77,3 +77,4 @@ docs/openapi.json:
- '#/paths/~1tokenize/post'
- '#/paths/~1v1~1chat~1completions/post'
- '#/paths/~1v1~1completions/post'
- '#/paths/~1v1~1models/get'

Cargo.lock (generated, 787 lines changed)

File diff suppressed because it is too large.

View File

@ -25,10 +25,12 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] }
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
[profile.release]
incremental = true

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -13,10 +13,13 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -37,6 +40,7 @@ COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@ -45,7 +49,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.4
ARG MAMBA_VERSION=24.3.0-0
@ -184,6 +188,12 @@ WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
RUN make install-flashinfer
# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
@ -210,32 +220,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from fbgemm builder
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
@ -250,7 +261,9 @@ RUN cd server && \
pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -17,6 +17,8 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -64,14 +66,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
hipsolver-dev \
rccl-dev \
cmake \
python3-dev && \
python3.11-dev && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
@ -89,10 +91,18 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \
rm ~/mambaforge.sh
# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN conda install intel::mkl-static intel::mkl-include
RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
@ -172,19 +182,19 @@ ENV HF_HOME=/data \
PORT=80
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server
COPY proto proto
@ -201,6 +211,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image
FROM base AS sagemaker

View File

@ -1,6 +1,6 @@
ARG PLATFORM=xpu
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@ -18,6 +18,8 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -42,9 +44,35 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu
USER root
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
@ -54,7 +82,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
# Text Generation Inference base env
ENV HF_HOME=/data \
@ -64,9 +92,7 @@ ENV HF_HOME=/data \
WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
# Install server
COPY proto proto
@ -81,14 +107,12 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
@ -116,7 +140,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
NUMBA_CACHE_DIR=/data/numba_cache
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
@ -133,12 +157,19 @@ RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
RUN pip install triton numa
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install triton py-libnuma
WORKDIR /usr/src
@ -151,10 +182,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# Install server
COPY proto proto
@ -173,5 +205,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} AS final
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -189,6 +189,8 @@ overridden with the `--otlp-service-name` argument
![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)
Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
### Local install
You can also opt to install `text-generation-inference` locally.

View File

@ -153,6 +153,8 @@ impl Client {
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Most requests will have that
add_special_tokens: true,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],

View File

@ -221,6 +221,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,

View File

@ -53,8 +53,8 @@ utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"

View File

@ -35,27 +35,15 @@ impl BackendV3 {
window_size: Option<u32>,
speculate: u32,
) -> Self {
let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
matches!(prefix_caching.as_str(), "true" | "1")
} else {
false
};
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
let prefix_caching =
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").expect("attention env var");
let attention: Attention = attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else if prefix_caching {
Attention::FlashInfer
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else if attention == Attention::FlashInfer {
1
} else {
16
};
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = attention.block_size();
let queue = Queue::new(
requires_padding,
@ -180,6 +168,8 @@ pub(crate) async fn batching_task(
None
} else {
// Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching.
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
@ -386,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// Send generation responses back to the infer task
// If we receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating, hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
let stopped = send_responses(generation, entry).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");

View File

@ -1,4 +1,4 @@
use std::{cmp::min, sync::Arc};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
@ -137,7 +137,6 @@ pub trait Allocator {
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
@ -167,7 +166,7 @@ impl Allocator for SimpleAllocator {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize)
}
};

View File

@ -149,6 +149,7 @@ impl Client {
requests.push(Request {
id: 0,
inputs,
add_special_tokens: true,
input_chunks: Some(Input {
chunks: input_chunks,
}),

View File

@ -222,6 +222,7 @@ impl Health for ShardedClient {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
add_special_tokens: true,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,

View File

@ -252,17 +252,14 @@ impl State {
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
let mut batch = Vec::with_capacity(self.entries.len());
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
let mut max_blocks = 0;
// Pop entries starting from the front of the queue
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
@ -276,7 +273,7 @@ impl State {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
@ -290,7 +287,7 @@ impl State {
}
None
}
Some(block_allocator) => {
Some(_block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
@ -324,13 +321,50 @@ impl State {
entry.request.input_ids.clone()
};
Some((tokens, input_ids))
}
};
batch.push((id, entry, block_allocation));
if Some(batch.len()) == max_size {
break;
}
}
// Empty batch
if batch.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// XXX We haven't allocated yet, so we're allowed to ditch the results.
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch.len() < min_size {
// Add back entries to the queue in the correct order
for (id, entry, _) in batch.into_iter().rev() {
self.entries.push_front((id, entry));
}
return None;
}
}
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation) in batch {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator)
{
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
continue;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
@ -338,9 +372,9 @@ impl State {
Some(block_allocation)
}
}
}
} else {
None
};
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
@ -383,6 +417,7 @@ impl State {
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
add_special_tokens: entry.request.add_special_tokens,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
)),
@ -399,11 +434,6 @@ impl State {
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
// Check if max_size
if Some(batch_requests.len()) == max_size {
break;
}
}
// Empty batch
@ -412,21 +442,6 @@ impl State {
return None;
}
// Check if our batch is big enough
if let Some(min_size) = min_size {
// Batch is too small
if batch_requests.len() < min_size {
// Add back entries to the queue in the correct order
for r in batch_requests.into_iter().rev() {
let id = r.id;
let entry = batch_entries.remove(&id).unwrap();
self.entries.push_front((id, entry));
}
return None;
}
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
@ -517,6 +532,7 @@ mod tests {
inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0,
add_special_tokens: true,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {

View File

@ -1,11 +1,21 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::hash::{Hash, Hasher};
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::{Allocator, BlockAllocation};
fn hash(slice: &[u32]) -> u64 {
assert!(!slice.is_empty());
if slice.len() == 1 {
slice[0] as u64
} else {
let mut s = std::hash::DefaultHasher::new();
slice.hash(&mut s);
s.finish()
}
}
pub struct RadixAllocator {
allocation_id: u64,
@ -16,26 +26,26 @@ pub struct RadixAllocator {
/// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>,
#[allow(dead_code)]
// This isn't used because the prefix needs to match without the windowing
// mechanism. At worst this is overallocating, not necessarily being wrong.
window_size: Option<u32>,
block_size: u32,
}
impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
assert_eq!(
block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}",
block_size
);
if window_size.is_some() {
unimplemented!("Window size not supported in the prefix-caching block allocator yet");
}
RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(),
cache_blocks: RadixTrie::new(block_size as usize),
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
window_size,
block_size,
}
}
@ -46,6 +56,10 @@ impl RadixAllocator {
// the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet.
tracing::debug!(
"Free blocks {} need {n_blocks_needed}",
self.free_blocks.len()
);
self.free_blocks.extend(
self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()),
@ -63,6 +77,7 @@ impl RadixAllocator {
}
}
// Allocator trait
impl Allocator for RadixAllocator {
fn allocate(
&mut self,
@ -74,24 +89,30 @@ impl Allocator for RadixAllocator {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
// Even if this allocation fails below, we need to increase the
// refcount to ensure that the prefix that was found is not evicted.
node_id
} else {
self.cache_blocks.root_id()
};
// Even if this allocation fails below, we need to increase the
// refcount to ensure that the prefix that was found is not evicted.
self.cache_blocks
.incref(prefix_node)
.expect("Failed to increment refcount");
let prefix_len = blocks.len();
let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;
match self.alloc_or_reclaim(suffix_len as usize) {
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => {
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
tracing::debug!("Block size {}", self.block_size);
self.cache_blocks
.decref(prefix_node)
.expect("Failed to decrement refcount");
@ -100,7 +121,20 @@ impl Allocator for RadixAllocator {
}
// 1:1 mapping of blocks and slots.
let slots = blocks.clone();
let slots = if self.block_size == 1 {
blocks.clone()
} else {
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
'slots: for block_id in &blocks {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() as u32 == tokens {
break 'slots;
}
}
}
slots
};
let allocation = RadixAllocation {
prefix_node,
@ -136,12 +170,17 @@ impl Allocator for RadixAllocator {
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
let aligned =
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
if aligned > 0 {
let prefix_len = self
.cache_blocks
.insert(prefill_tokens, &blocks[..prefill_tokens.len()])
.insert(
&prefill_tokens[..aligned],
&blocks[..aligned / self.block_size as usize],
)
// Unwrap, failing is a programming error.
.expect("Failed to store prefill tokens");
// We can have a prefill with the following structure:
//
// |---| From the prefix cache.
@ -151,12 +190,18 @@ impl Allocator for RadixAllocator {
// This means that while processing this request there was a
// partially overlapping request that had A..=E in its
// prefill. In this case we need to free the blocks D E.
self.free_blocks
.extend(&blocks[allocation.cached_prefix_len..prefix_len]);
if prefix_len > allocation.cached_prefix_len {
self.free_blocks.extend(
&blocks[allocation.cached_prefix_len / self.block_size as usize
..prefix_len / self.block_size as usize],
);
}
}
}
// Free non-prefill blocks.
self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
self.free_blocks
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
} else {
self.free_blocks.extend(blocks);
}
@ -185,7 +230,6 @@ struct RadixAllocation {
pub enum TrieError {
InvalidNodeId,
RefCountUnderflow,
BlockTokenCountMismatch,
}
pub type NodeId = DefaultKey;
@ -204,17 +248,14 @@ pub struct RadixTrie {
/// Time as a monotonically increasing counter to avoid the system
/// call that a real time lookup would require.
time: u64,
}
impl Default for RadixTrie {
fn default() -> Self {
Self::new()
}
/// All blocks need to be aligned with this
block_size: usize,
}
impl RadixTrie {
/// Construct a new radix trie.
pub fn new() -> Self {
pub fn new(block_size: usize) -> Self {
let root = TrieNode::new(vec![], vec![], 0, None);
let mut nodes = SlotMap::new();
let root = nodes.insert(root);
@ -223,13 +264,14 @@ impl RadixTrie {
nodes,
root,
time: 0,
block_size,
}
}
/// Find the prefix of the given tokens.
///
/// The blocks corresponding to the part of the prefix that could be found
/// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// Returns the identifier of the trie node that contains the longest
/// prefix. The node identifier can be used by callers to e.g. increase its
/// reference count.
@ -244,17 +286,21 @@ impl RadixTrie {
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if let Some(&child_id) = node.children.get(&key[0]) {
if key.len() >= self.block_size {
let node_key = hash(&key[..self.block_size]);
if let Some(&child_id) = node.children.get(&node_key) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = child.key.shared_prefix_len(key);
blocks.extend(&child.blocks[..shared_prefix_len]);
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
}
}
}
node_id
}
@ -277,6 +323,11 @@ impl RadixTrie {
node.ref_count -= 1;
if node.ref_count == 0 {
assert!(
node.children.is_empty(),
"Nodes with children must have refcount > 0"
);
self.leaves.insert((node.last_accessed, node_id));
}
@ -304,7 +355,7 @@ impl RadixTrie {
/// Evict `n_blocks` from the trie.
///
/// Returns the evicted blocks. When the length is less than `n_blocks`,
/// not enough blocks could beevicted.
/// not enough blocks could be evicted.
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
// NOTE: we don't return Result here. If any of the unwrapping fails,
// it's a programming error in the trie implementation, not a user
@ -314,11 +365,19 @@ impl RadixTrie {
// evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case.
let mut evicted = Vec::new();
tracing::debug!("Evicting in search of {n_blocks}");
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks - evicted.len();
let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leave does not exist");
assert_eq!(
node.ref_count, 0,
"Leaf must have refcount of 0, got {}",
node.ref_count
);
if blocks_needed >= node.blocks.len() {
// We need to evict the whole node if we need more blocks than it has.
let node = self.remove_node(node_id);
@ -332,8 +391,11 @@ impl RadixTrie {
// the required number of blocks and leave the remaining blocks
// untouched.
let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
node.key.truncate(node.blocks.len() - blocks_needed);
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
let truncate_blocks = node.blocks.len() - blocks_needed;
let truncate_tokens = truncate_blocks * self.block_size;
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
break;
}
@ -349,7 +411,8 @@ impl RadixTrie {
/// the first 10 elements of the tree **the blocks are not updated**.
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
self.time += 1;
self.insert_(self.root, tokens, blocks)
let common = self.insert_(self.root, tokens, blocks)?;
Ok(common)
}
/// Insertion worker.
@ -363,21 +426,20 @@ impl RadixTrie {
// the part of the prefix that is already in the trie to detect
// mismatches.
if tokens.len() != blocks.len() {
return Err(TrieError::BlockTokenCountMismatch);
}
assert_eq!(tokens.len(), blocks.len() * self.block_size);
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
let node_key = hash(&tokens[..self.block_size]);
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
self.update_access_time(child_id);
let child = self
.nodes
.get_mut(child_id)
// Unwrap here, since failure is a bug.
.expect("Child node does not exist");
let shared_prefix_len = child.key.shared_prefix_len(tokens);
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
// We are done, the prefix is already in the trie.
if shared_prefix_len == tokens.len() {
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
return Ok(shared_prefix_len);
}
@ -387,7 +449,7 @@ impl RadixTrie {
+ self.insert_(
child_id,
&tokens[shared_prefix_len..],
&blocks[shared_prefix_len..],
&blocks[shared_prefix_len / self.block_size..],
)?);
}
@ -396,7 +458,7 @@ impl RadixTrie {
// remainder of the prefix into the node again
let child_id = self.split_node(child_id, shared_prefix_len);
let key = &tokens[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len / self.block_size..];
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
} else {
self.add_node(node_id, tokens, blocks);
@ -415,14 +477,15 @@ impl RadixTrie {
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len);
let prefix_blocks = prefix_len / self.block_size;
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = node.key[0];
let node_key = hash(&node.key[..self.block_size]);
let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
@ -447,7 +510,7 @@ impl RadixTrie {
) -> NodeId {
let key = key.into();
let blocks = blocks.into();
let first = key[0];
let first = hash(&key[..self.block_size]);
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child);
@ -459,10 +522,10 @@ impl RadixTrie {
}
/// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(first, child_id).is_none() {
if parent.children.insert(hash, child_id).is_none() {
// Only increase reference count if child does not replace another child.
self.incref(parent_id)
.expect("Failed to increase parent refcount");
@ -473,12 +536,18 @@ impl RadixTrie {
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.remove(node_id).expect("Unknown node");
assert!(
node.children.is_empty(),
"Tried to remove a node with {} children",
node.children.len()
);
let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
parent.children.remove(&node.key[0]);
let node_key = hash(&node.key[..self.block_size]);
parent.children.remove(&node_key);
self.decref(parent_id)
.expect("Failed to decrease parent refcount");
self.nodes.remove(node_id);
node
}
@ -530,7 +599,7 @@ impl RadixTrie {
#[derive(Debug)]
struct TrieNode {
blocks: Vec<u32>,
children: HashMap<u32, NodeId>,
children: HashMap<u64, NodeId>,
key: Vec<u32>,
last_accessed: u64,
parent: Option<NodeId>,
@ -550,34 +619,56 @@ impl TrieNode {
}
}
/// Helper trait to get the length of the shared prefix of two sequences.
trait SharedPrefixLen {
fn shared_prefix_len(&self, other: &Self) -> usize;
}
impl<T> SharedPrefixLen for [T]
where
T: PartialEq,
{
fn shared_prefix_len(&self, other: &Self) -> usize {
self.iter().zip(other).take_while(|(a, b)| a == b).count()
}
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
// NOTE: this is the case because the child node was chosen based on
// matching the first character of the key/prefix.
assert!(full > 0, "Prefixes must at least share 1 token");
(full / block_size) * block_size
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::block_allocator::Allocator;
use super::*;
use super::RadixAllocator;
#[test]
fn allocator_block_size() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_block_size_non_aligned() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 2);
}
#[test]
fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.slots, allocation.slots);
assert_eq!(allocation.blocks, allocation.slots);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
@ -666,7 +757,7 @@ mod tests {
#[test]
fn trie_insertions_have_correct_prefix_len() {
let mut trie = super::RadixTrie::new();
let mut trie = RadixTrie::new(1);
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
@ -687,9 +778,33 @@ mod tests {
);
}
#[test]
fn trie_insertions_block_size() {
let mut trie = RadixTrie::new(2);
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
// Already exists.
// But needs to be block_size aligned
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
.unwrap(),
2
);
}
#[test]
fn trie_get_returns_correct_blocks() {
let mut trie = super::RadixTrie::new();
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
@ -723,7 +838,7 @@ mod tests {
#[test]
fn trie_evict_removes_correct_blocks() {
let mut trie = super::RadixTrie::new();
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
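
To summarize the block-size handling these tests exercise: prefix matches are rounded down to whole blocks, and each block covers `block_size` consecutive slots. A minimal Python sketch of just that arithmetic (mirroring, not replacing, the Rust helpers above):

```python
def shared_prefix(left, right, block_size):
    # Length of the common prefix, rounded down to a whole number of blocks.
    full = 0
    for a, b in zip(left, right):
        if a != b:
            break
        full += 1
    return (full // block_size) * block_size


def blocks_to_slots(blocks, block_size, tokens):
    # Each block id covers block_size consecutive slots; stop once there is
    # one slot per token (the non-aligned case).
    slots = []
    for block_id in blocks:
        for s in range(block_id * block_size, (block_id + 1) * block_size):
            slots.append(s)
            if len(slots) == tokens:
                return slots
    return slots


# Matches the allocator_block_size test: blocks 8..=11 with block_size 2
# cover slots 16..=23, and a 3-token overlap only counts as one full block.
assert blocks_to_slots([8, 9, 10, 11], block_size=2, tokens=8) == list(range(16, 24))
assert shared_prefix([0, 1, 2, 3], [0, 1, 2, 9], block_size=2) == 2
```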

View File

@ -16,7 +16,7 @@ path = "src/main.rs"
[dependencies]
average = "0.14"
clap = { version = "4.4.5", features = ["derive", "env"] }
crossterm = "0.27"
crossterm = "0.28.1"
float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0"
@ -25,7 +25,7 @@ text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
ratatui = { version = "0.28.1", default-features = false, features = ["crossterm"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = { workspace = true }

View File

@ -1,16 +1,15 @@
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
use crate::generation::{Decode, Message, Prefill};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
use tui::backend::Backend;
use tui::layout::{Alignment, Constraint, Direction, Layout};
use tui::style::{Color, Modifier, Style};
use tui::text::{Line, Span};
use tui::widgets::{
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
use ratatui::style::{Color, Modifier, Style};
use ratatui::text::{Line, Span};
use ratatui::widgets::{
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
};
use tui::{symbols, Frame};
use ratatui::{symbols, Frame};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
/// TUI powered App
pub(crate) struct App {
@ -153,7 +152,7 @@ impl App {
}
/// Render frame
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
pub fn render(&mut self, f: &mut Frame) {
let batch_progress =
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
let run_progress =
@ -172,7 +171,7 @@ impl App {
]
.as_ref(),
)
.split(f.size());
.split(f.area());
// Top row horizontal layout
let top = Layout::default()
@ -239,7 +238,7 @@ impl App {
f.render_widget(helper, row5[0]);
// Batch tabs
let titles = self
let titles: Vec<Line> = self
.data
.batch_size
.iter()

View File

@ -148,6 +148,7 @@ async fn prefill(
}),
inputs: sequence.clone(),
truncate: sequence_length,
add_special_tokens: true,
parameters: Some(parameters.clone()),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length,

View File

@ -7,12 +7,12 @@ mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use ratatui::backend::CrosstermBackend;
use ratatui::Terminal;
use std::io;
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
use tui::Terminal;
/// Run benchmarking app
#[allow(clippy::too_many_arguments)]

View File

@ -757,7 +757,12 @@ class AsyncClient:
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
payload_data = (
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
)
if payload_data == "[DONE]":
break
json_payload = json.loads(payload_data)
try:
response = ChatCompletionChunk(**json_payload)
yield response
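
For context, the new branch handles OpenAI-style server-sent events, where the JSON body follows a `data:` prefix (often with a space) and the stream ends with a literal `[DONE]` sentinel that is not valid JSON. A minimal sketch of that parsing on made-up payloads:

```python
import json

# Hypothetical SSE payloads in the shape an OpenAI-compatible endpoint emits.
raw_events = [
    'data: {"choices": [{"delta": {"content": "Hello"}}]}\n',
    'data: {"choices": [{"delta": {"content": " world"}}]}\n',
    "data: [DONE]\n",
]

for payload in raw_events:
    if payload.startswith("data:"):
        # Same normalization as above: drop the prefix, the trailing newline,
        # and the optional space before the JSON body.
        payload_data = payload.lstrip("data:").rstrip("\n").removeprefix(" ")
        if payload_data == "[DONE]":
            break  # end-of-stream sentinel, stop instead of parsing it as JSON
        chunk = json.loads(payload_data)
        print(chunk["choices"][0]["delta"]["content"], end="")
```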

View File

@ -556,6 +556,37 @@
}
}
}
},
"/v1/models": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Get model info",
"operationId": "openai_get_model_info",
"responses": {
"200": {
"description": "Served model info",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ModelInfo"
}
}
}
},
"404": {
"description": "Model not found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
}
}
}
}
}
}
}
},
"components": {
@ -924,7 +955,7 @@
"tool_prompt": {
"type": "string",
"description": "A prompt to be appended before the tools",
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
"example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
"nullable": true
},
"tools": {
@ -1747,6 +1778,35 @@
}
]
},
"ModelInfo": {
"type": "object",
"required": [
"id",
"object",
"created",
"owned_by"
],
"properties": {
"created": {
"type": "integer",
"format": "int64",
"example": 1686935002,
"minimum": 0
},
"id": {
"type": "string",
"example": "gpt2"
},
"object": {
"type": "string",
"example": "model"
},
"owned_by": {
"type": "string",
"example": "openai"
}
}
},
"OutputMessage": {
"oneOf": [
{
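
The new `/v1/models` route exposes OpenAI-style metadata about the served model. A minimal sketch of querying it, assuming a TGI instance on `localhost:3000`:

```python
import requests

# Hypothetical local deployment; adjust the base URL for your server.
resp = requests.get("http://localhost:3000/v1/models", timeout=10)
resp.raise_for_status()

# Per the ModelInfo schema above: id, object, created, owned_by.
print(resp.json())
```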

View File

@ -71,6 +71,8 @@
title: How Guidance Works (via outlines)
- local: conceptual/lora
title: LoRA (Low-Rank Adaptation)
- local: conceptual/external
title: External Resources
title: Conceptual Guides

View File

@ -157,7 +157,12 @@ from huggingface_hub import InferenceClient
client = InferenceClient("http://localhost:3000")
regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)"
section_regex = "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
regexp = f"HELLO\.{section_regex}\.WORLD\.{section_regex}"
# This is a more realistic example of an ip address regex
# regexp = f"{section_regex}\.{section_regex}\.{section_regex}\.{section_regex}"
resp = client.text_generation(
f"Whats Googles DNS? Please use the following regex: {regexp}",
@ -170,7 +175,7 @@ resp = client.text_generation(
print(resp)
# 7.1.1.1
# HELLO.255.WORLD.255
```
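As a quick sanity check (a sketch that reuses the `regexp` and `resp` variables from the snippet above), the constrained output can be validated against the same pattern:

```python
import re

# Grammar-constrained generation should yield text matching the pattern,
# e.g. "HELLO.255.WORLD.255" for the example regex above.
assert re.fullmatch(regexp, resp) is not None
```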

View File

@ -0,0 +1,4 @@
# External Resources
- Adyen wrote a detailed article about the interplay between TGI's main components: router and server.
[LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)

View File

@ -1,5 +1,6 @@
# Streaming
## What is Streaming?
Token streaming is the mode in which the server returns tokens one by one as the model generates them. This lets users see the generation progressively instead of waiting for the whole output to be ready. Streaming is an essential part of the end-user experience because it reduces latency, one of the most critical aspects of a smooth experience.
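As a concrete illustration, here is a minimal sketch of consuming the token stream with `InferenceClient` (it assumes a TGI server is already running on `localhost:3000`, as in the other examples in these docs):

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:3000")

# With stream=True the client yields tokens one by one as they are generated,
# so they can be displayed progressively instead of after the full generation.
for token in client.text_generation(
    "What is Deep Learning?", max_new_tokens=20, stream=True
):
    print(token, end="", flush=True)
```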

View File

@ -12,7 +12,24 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
--model-id $model --cuda-graphs 0
```
# Using TGI with Intel CPUs
Intel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. IPEX provides optimized operations such as flash attention, paged attention, Add + LayerNorm, RoPE, and more.
On a server powered by an Intel CPU, TGI can be launched with the following command:
```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
--model-id $model --cuda-graphs 0
```
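Once the container is up, it can be queried like any other TGI deployment. A minimal sketch, assuming the server is reachable on port 80 of the host (the default with `--net host`):

```python
import requests

# Standard TGI /generate call; the Intel CPU image exposes the same HTTP API.
resp = requests.post(
    "http://localhost:80/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
)
print(resp.json()["generated_text"])
```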

View File

@ -492,6 +492,24 @@
"type": "github"
}
},
"flake-utils_7": {
"inputs": {
"systems": "systems_7"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"gitignore": {
"inputs": {
"nixpkgs": [
@ -700,16 +718,16 @@
},
"nixpkgs_6": {
"locked": {
"lastModified": 1723912943,
"narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
"owner": "danieldk",
"lastModified": 1724915739,
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "cuda-12.4",
"owner": "nixos",
"ref": "nixos-unstable-small",
"repo": "nixpkgs",
"type": "github"
}
@ -835,11 +853,11 @@
]
},
"locked": {
"lastModified": 1724206841,
"narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=",
"lastModified": 1726280639,
"narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61",
"rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
"type": "github"
},
"original": {
@ -938,17 +956,33 @@
"type": "github"
}
},
"systems_7": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat_4",
"flake-utils": "flake-utils_7",
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1724270760,
"narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=",
"lastModified": 1726229792,
"narHash": "sha256-9xsLmjc9nr7a4PTddKv2DOi82ompTtJNyjO6R67y5tE=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c",
"rev": "1a902f4818e94c3f8d95f6000db17bc3fadd0ce7",
"type": "github"
},
"original": {

View File

@ -31,15 +31,12 @@
src = ./.;
additionalCargoNixArgs = [ "--all-features" ];
};
config = {
allowUnfree = true;
cudaSupport = true;
};
pkgs = import nixpkgs {
inherit config system;
inherit system;
inherit (tgi-nix.lib) config;
overlays = [
rust-overlay.overlays.default
tgi-nix.overlay
tgi-nix.overlays.default
];
};
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
@ -49,12 +46,52 @@
launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
inherit crateOverrides;
};
router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
router =
let
routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
inherit crateOverrides;
};
packagePath =
with pkgs.python3.pkgs;
makePythonPath [
protobuf
sentencepiece
torch
transformers
];
in
pkgs.writeShellApplication {
name = "text-generation-router";
text = ''
PYTHONPATH="${packagePath}" ${routerUnwrapped}/bin/text-generation-router "$@"
'';
};
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
in
{
checks = {
rust = with pkgs; rustPlatform.buildRustPackage {
name = "rust-checks";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
};
buildInputs = [ openssl.dev ];
nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
buildPhase = ''
cargo check
'';
checkPhase = ''
cargo fmt -- --check
cargo test -j $NIX_BUILD_CORES
cargo clippy
'';
installPhase = "touch $out";
} ;
};
formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec {
default = pure;
@ -66,6 +103,29 @@
server
];
};
test = mkShell {
buildInputs =
[
# benchmark
# launcher
# router
server
openssl.dev
pkg-config
cargo
rustfmt
clippy
]
++ (with python3.pkgs; [
docker
pytest
pytest-asyncio
syrupy
pre-commit
ruff
]);
};
impure = mkShell {
buildInputs =
@ -82,16 +142,25 @@
]
++ (with python3.pkgs; [
venvShellHook
docker
pip
ipdb
click
pyright
pytest
pytest-asyncio
ruff
syrupy
]);
inputsFrom = [ server ];
venvDir = "./.venv";
postVenv = ''
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
@ -99,6 +168,17 @@
'';
};
};
packages.default = pkgs.writeShellApplication {
name = "text-generation-inference";
runtimeInputs = [
server
router
];
text = ''
${launcher}/bin/text-generation-launcher "$@"
'';
};
}
);
}

View File

@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
BestOfSequence,
Message,
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
@ -64,6 +65,7 @@ class ResponseComparator(JSONSnapshotExtension):
self,
data,
*,
include=None,
exclude=None,
matcher=None,
):
@ -79,7 +81,12 @@ class ResponseComparator(JSONSnapshotExtension):
data = [d.model_dump() for d in data]
data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
data=data,
depth=0,
path=(),
exclude=exclude,
include=include,
matcher=matcher,
)
return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"
@ -91,7 +98,14 @@ class ResponseComparator(JSONSnapshotExtension):
) -> bool:
def convert_data(data):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
return _convert_data(data)
def _convert_data(data):
if isinstance(data, Dict):
if "choices" in data:
data["choices"] = list(
sorted(data["choices"], key=lambda x: x["index"])
)
choices = data["choices"]
if isinstance(choices, List) and len(choices) >= 1:
if "delta" in choices[0]:
@ -99,17 +113,10 @@ class ResponseComparator(JSONSnapshotExtension):
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data)
if isinstance(data, Dict):
else:
return Response(**data)
if isinstance(data, List):
if (
len(data) > 0
and "object" in data[0]
and data[0]["object"] == "text_completion"
):
return [Completion(**d) for d in data]
return [Response(**d) for d in data]
return [_convert_data(d) for d in data]
raise NotImplementedError
def eq_token(token: Token, other: Token) -> bool:
@ -257,7 +264,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}")
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
def _inner_health(self):
raise NotImplementedError
@ -335,6 +342,7 @@ def launcher(event_loop):
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -394,6 +402,8 @@ def launcher(event_loop):
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
@ -430,6 +440,7 @@ def launcher(event_loop):
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
attention: Optional[str] = None,
):
port = random.randint(8000, 10_000)
@ -484,6 +495,8 @@ def launcher(event_loop):
}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if attention is not None:
env["ATTENTION"] = attention
if HF_TOKEN is not None:
env["HF_TOKEN"] = HF_TOKEN
@ -515,6 +528,7 @@ def launcher(event_loop):
devices=devices,
volumes=volumes,
ports={"80/tcp": port},
healthcheck={"timeout": int(10 * 1e9)},
shm_size="1G",
)
@ -565,3 +579,38 @@ def generate_load():
return await asyncio.gather(*futures)
return generate_load_inner
@pytest.fixture(scope="module")
def generate_multi():
async def generate_load_inner(
client: AsyncClient,
prompts: List[str],
max_new_tokens: int,
seed: Optional[int] = None,
) -> List[Response]:
import numpy as np
arange = np.arange(len(prompts))
perm = np.random.permutation(arange)
rperm = [-1] * len(perm)
for i, p in enumerate(perm):
rperm[p] = i
shuffled_prompts = [prompts[p] for p in perm]
futures = [
client.chat(
messages=[Message(role="user", content=prompt)],
max_tokens=max_new_tokens,
temperature=0,
seed=seed,
)
for prompt in shuffled_prompts
]
shuffled_responses = await asyncio.gather(*futures)
responses = [shuffled_responses[p] for p in rperm]
return responses
return generate_load_inner

View File

@ -5,7 +5,7 @@
"index": 0,
"logprobs": null,
"message": {
"content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas",
"content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C",
"name": null,
"role": "assistant",
"tool_calls": null
@ -13,14 +13,14 @@
"usage": null
}
],
"created": 1716553098,
"created": 1724792495,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.5-dev0-native",
"object": "chat.completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 62,
"total_tokens": 162
"prompt_tokens": 61,
"total_tokens": 161
}
}

View File

@ -1,38 +1,38 @@
{
"choices": [
{
"finish_reason": "stop",
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " A Beginners Guide\nDeep learning is a subset"
},
{
"finish_reason": "length",
"index": 1,
"logprobs": null,
"text": " PR for more information?"
"text": " This is a question that has puzzled many people for"
},
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": "hd20220811-"
},
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
"text": "usculas_minusculas(s):\n \"\"\"\n"
},
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": " severely flawed and often has a substandard"
"text": " Paris\nWhat is the capital of France?\nThe"
}
],
"created": 1722014725,
"created": 1725877154,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 8,
"total_tokens": 44
"completion_tokens": 40,
"prompt_tokens": 22,
"total_tokens": 62
}
}

View File

@ -5,14 +5,14 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
"text": " A"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -20,14 +20,74 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
"text": " This"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " Paris"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "us"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Beginner"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -38,11 +98,11 @@
"text": "\n"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -50,14 +110,14 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "hd"
"text": "cul"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -65,14 +125,14 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
"text": "s"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -80,14 +140,14 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
"text": " a"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -95,14 +155,14 @@
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
"text": "What"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -110,14 +170,14 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "aho"
"text": "as"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -125,14 +185,14 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "2"
"text": " Guide"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -140,254 +200,14 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "2"
"text": " question"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -398,11 +218,11 @@
"text": " is"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -410,14 +230,14 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "m"
"text": "_minus"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -425,14 +245,14 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Room"
"text": "\n"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -440,14 +260,14 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "s"
"text": " that"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -458,11 +278,11 @@
"text": " the"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -470,14 +290,14 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " tired"
"text": "cul"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -485,14 +305,14 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": ":"
"text": "Deep"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -500,14 +320,14 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "'"
"text": " has"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -518,11 +338,11 @@
"text": " capital"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -530,14 +350,14 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " of"
"text": "as"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -545,14 +365,14 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " She"
"text": " learning"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -560,14 +380,14 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " scale"
"text": " puzzled"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -578,11 +398,11 @@
"text": " of"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
@ -590,13 +410,193 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " being"
"text": "(s"
}
],
"created": 1713284431,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native"
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " many"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " France"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "):\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " people"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "?\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " "
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " subset"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "length",
"index": 1,
"logprobs": null,
"text": " for"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": "The"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": " \"\"\"\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
}
]

View File

@ -4,17 +4,17 @@
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " PR for flake8"
"text": " A Beginners Guide\nDeep learning is a subset"
}
],
"created": 1713284454,
"created": 1725876621,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 5,
"completion_tokens": 10,
"prompt_tokens": 6,
"total_tokens": 11
"total_tokens": 16
}
}

View File

@ -16,7 +16,7 @@
},
{
"id": 3102,
"logprob": -11.1875,
"logprob": -11.25,
"text": " request"
}
],
@ -24,66 +24,66 @@
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"logprob": -1.546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.84375,
"logprob": -2.859375,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.34375,
"logprob": -2.484375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.8359375,
"logprob": -0.83203125,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.0859375,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 254,
"logprob": -1.5390625,
"id": 245,
"logprob": -1.578125,
"special": false,
"text": " the"
"text": " a"
},
{
"id": 1022,
"logprob": -1.1875,
"id": 3412,
"logprob": -2.578125,
"special": false,
"text": " first"
"text": " document"
},
{
"id": 3458,
"logprob": -0.35546875,
"id": 344,
"logprob": -1.125,
"special": false,
"text": " step"
"text": " that"
},
{
"id": 279,
"logprob": -0.8828125,
"id": 317,
"logprob": -1.6953125,
"special": false,
"text": " in"
"text": " is"
},
{
"id": 254,
"logprob": -0.71484375,
"id": 1222,
"logprob": -1.71875,
"special": false,
"text": " the"
"text": " used"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is the first step in the"
"generated_text": "\nThe test request is a document that is used"
}

View File

@ -37,56 +37,56 @@
},
{
"id": 1727,
"logprob": -2.359375,
"logprob": -2.4375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.83203125,
"logprob": -0.83984375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.125,
"logprob": -1.1328125,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5703125,
"id": 254,
"logprob": -1.515625,
"special": false,
"text": " a"
"text": " the"
},
{
"id": 3412,
"logprob": -2.578125,
"id": 1022,
"logprob": -1.15625,
"special": false,
"text": " document"
"text": " first"
},
{
"id": 344,
"logprob": -1.125,
"id": 3458,
"logprob": -0.3671875,
"special": false,
"text": " that"
"text": " step"
},
{
"id": 317,
"logprob": -1.6953125,
"id": 279,
"logprob": -0.88671875,
"special": false,
"text": " is"
"text": " in"
},
{
"id": 1222,
"logprob": -1.75,
"id": 254,
"logprob": -0.69140625,
"special": false,
"text": " used"
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a document that is used"
"generated_text": "\nThe test request is the first step in the"
},
{
"details": {
@ -126,56 +126,56 @@
},
{
"id": 1727,
"logprob": -2.359375,
"logprob": -2.4375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.83203125,
"logprob": -0.83984375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.125,
"logprob": -1.1328125,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5703125,
"id": 254,
"logprob": -1.515625,
"special": false,
"text": " a"
"text": " the"
},
{
"id": 3412,
"logprob": -2.578125,
"id": 1022,
"logprob": -1.15625,
"special": false,
"text": " document"
"text": " first"
},
{
"id": 344,
"logprob": -1.125,
"id": 3458,
"logprob": -0.3671875,
"special": false,
"text": " that"
"text": " step"
},
{
"id": 317,
"logprob": -1.6953125,
"id": 279,
"logprob": -0.88671875,
"special": false,
"text": " is"
"text": " in"
},
{
"id": 1222,
"logprob": -1.75,
"id": 254,
"logprob": -0.69140625,
"special": false,
"text": " used"
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a document that is used"
"generated_text": "\nThe test request is the first step in the"
},
{
"details": {
@ -215,56 +215,56 @@
},
{
"id": 1727,
"logprob": -2.359375,
"logprob": -2.4375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.83203125,
"logprob": -0.83984375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.125,
"logprob": -1.1328125,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5703125,
"id": 254,
"logprob": -1.515625,
"special": false,
"text": " a"
"text": " the"
},
{
"id": 3412,
"logprob": -2.578125,
"id": 1022,
"logprob": -1.15625,
"special": false,
"text": " document"
"text": " first"
},
{
"id": 344,
"logprob": -1.125,
"id": 3458,
"logprob": -0.3671875,
"special": false,
"text": " that"
"text": " step"
},
{
"id": 317,
"logprob": -1.6953125,
"id": 279,
"logprob": -0.88671875,
"special": false,
"text": " is"
"text": " in"
},
{
"id": 1222,
"logprob": -1.75,
"id": 254,
"logprob": -0.69140625,
"special": false,
"text": " used"
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a document that is used"
"generated_text": "\nThe test request is the first step in the"
},
{
"details": {
@ -304,55 +304,55 @@
},
{
"id": 1727,
"logprob": -2.359375,
"logprob": -2.4375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.83203125,
"logprob": -0.83984375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.125,
"logprob": -1.1328125,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5703125,
"id": 254,
"logprob": -1.515625,
"special": false,
"text": " a"
"text": " the"
},
{
"id": 3412,
"logprob": -2.578125,
"id": 1022,
"logprob": -1.15625,
"special": false,
"text": " document"
"text": " first"
},
{
"id": 344,
"logprob": -1.125,
"id": 3458,
"logprob": -0.3671875,
"special": false,
"text": " that"
"text": " step"
},
{
"id": 317,
"logprob": -1.6953125,
"id": 279,
"logprob": -0.88671875,
"special": false,
"text": " is"
"text": " in"
},
{
"id": 1222,
"logprob": -1.75,
"id": 254,
"logprob": -0.69140625,
"special": false,
"text": " used"
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a document that is used"
"generated_text": "\nThe test request is the first step in the"
}
]

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"prefill": [
{
"id": 128000,
@ -16,7 +16,7 @@
},
{
"id": 1715,
"logprob": -10.375,
"logprob": -10.4375,
"text": " request"
}
],
@ -29,61 +29,31 @@
"text": ":"
},
{
"id": 2209,
"logprob": -2.78125,
"id": 923,
"logprob": -2.84375,
"special": false,
"text": " Is"
"text": " add"
},
{
"id": 279,
"logprob": -0.6328125,
"id": 264,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 734,
"logprob": -2.703125,
"special": false,
"text": " function"
"text": " a"
},
{
"id": 330,
"logprob": -0.34179688,
"logprob": -0.31640625,
"special": false,
"text": " \""
},
{
"id": 4110,
"logprob": -2.359375,
"id": 1985,
"logprob": 0.0,
"special": false,
"text": "Create"
},
{
"id": 7575,
"logprob": -2.1875,
"special": false,
"text": "Process"
},
{
"id": 1,
"logprob": -0.07910156,
"special": false,
"text": "\""
},
{
"id": 304,
"logprob": -0.83203125,
"special": false,
"text": " in"
},
{
"id": 12468,
"logprob": -1.8203125,
"special": false,
"text": " Win"
"text": "test"
}
],
"top_tokens": null
},
"generated_text": "Test request: Is the function \"CreateProcess\" in Win"
"generated_text": "Test request: add a \"test"
}

View File

@ -0,0 +1,114 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1445312,
"text": "What"
},
{
"id": 349,
"logprob": -1.4648438,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.6005859,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39526367,
"text": "?"
},
{
"id": 13,
"logprob": -0.640625,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18774414,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.96484375,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.003168106,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16540527,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.08886719,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.75878906,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5703125,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11242676,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.7939453,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17102051,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34326172,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
}

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 24871,
"logprob": -17.234375,
"text": "descent"
},
{
"id": 28804,
"logprob": -7.4335938,
"text": "?"
},
{
"id": 13,
"logprob": -0.8017578,
"text": "\n"
},
{
"id": 13,
"logprob": -0.32958984,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 1313,
"logprob": -2.3613281,
"special": false,
"text": "It"
},
{
"id": 3969,
"logprob": -0.7285156,
"special": false,
"text": " seems"
},
{
"id": 298,
"logprob": -1.3466797,
"special": false,
"text": " to"
},
{
"id": 528,
"logprob": 0.0,
"special": false,
"text": " me"
},
{
"id": 28725,
"logprob": -1.6757812,
"special": false,
"text": ","
},
{
"id": 369,
"logprob": -0.06585693,
"special": false,
"text": " that"
},
{
"id": 513,
"logprob": -1.1269531,
"special": false,
"text": " if"
},
{
"id": 368,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 28742,
"logprob": -2.4921875,
"special": false,
"text": "'"
},
{
"id": 267,
"logprob": 0.0,
"special": false,
"text": "re"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nIt seems to me, that if you're"
}

View File

@ -0,0 +1,458 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1445312,
"text": "What"
},
{
"id": 349,
"logprob": -1.4648438,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.6005859,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39526367,
"text": "?"
},
{
"id": 13,
"logprob": -0.640625,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18774414,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.96484375,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.003168106,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16369629,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.0881958,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.76708984,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.57373047,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11291504,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.79589844,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.1694336,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34350586,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1445312,
"text": "What"
},
{
"id": 349,
"logprob": -1.4677734,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.6015625,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39453125,
"text": "?"
},
{
"id": 13,
"logprob": -0.6435547,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18713379,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.9628906,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.0032176971,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16540527,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.08898926,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.765625,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5708008,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11401367,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.7963867,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17028809,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34326172,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.140625,
"text": "What"
},
{
"id": 349,
"logprob": -1.4658203,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6796875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.5898438,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.3955078,
"text": "?"
},
{
"id": 13,
"logprob": -0.64501953,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18493652,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.9580078,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.0032176971,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16552734,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.08874512,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.75878906,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5703125,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11236572,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.79541016,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17102051,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34326172,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1824,
"logprob": -6.1328125,
"text": "What"
},
{
"id": 349,
"logprob": -1.4658203,
"text": "is"
},
{
"id": 21135,
"logprob": -13.6796875,
"text": "gradient"
},
{
"id": 24871,
"logprob": -1.5947266,
"text": "descent"
},
{
"id": 28804,
"logprob": -0.39648438,
"text": "?"
},
{
"id": 13,
"logprob": -0.6464844,
"text": "\n"
},
{
"id": 13,
"logprob": -0.18688965,
"text": "\n"
}
],
"seed": null,
"tokens": [
{
"id": 20910,
"logprob": -0.9609375,
"special": false,
"text": "Grad"
},
{
"id": 722,
"logprob": -0.003168106,
"special": false,
"text": "ient"
},
{
"id": 24871,
"logprob": -0.16601562,
"special": false,
"text": " descent"
},
{
"id": 349,
"logprob": -0.088134766,
"special": false,
"text": " is"
},
{
"id": 396,
"logprob": -0.7597656,
"special": false,
"text": " an"
},
{
"id": 18586,
"logprob": -0.5708008,
"special": false,
"text": " optimization"
},
{
"id": 9464,
"logprob": -0.11291504,
"special": false,
"text": " algorithm"
},
{
"id": 1307,
"logprob": -0.7944336,
"special": false,
"text": " used"
},
{
"id": 298,
"logprob": -0.17102051,
"special": false,
"text": " to"
},
{
"id": 26518,
"logprob": -0.34399414,
"special": false,
"text": " minimize"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
}
]

View File

@ -16,7 +16,7 @@
},
{
"id": 100,
"logprob": -0.38549805,
"logprob": -0.38305664,
"text": "_"
},
{
@ -29,7 +29,7 @@
"tokens": [
{
"id": 2284,
"logprob": -0.31323242,
"logprob": -0.296875,
"special": false,
"text": "():"
},
@ -59,19 +59,19 @@
},
{
"id": 10914,
"logprob": -0.7817383,
"logprob": -0.7734375,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.6328125,
"logprob": -0.61816406,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.0619812,
"logprob": -0.054870605,
"special": false,
"text": "\n"
},
@ -83,7 +83,7 @@
},
{
"id": 610,
"logprob": -0.4086914,
"logprob": -0.4152832,
"special": false,
"text": "def"
},
@ -113,7 +113,7 @@
},
{
"id": 444,
"logprob": -0.21826172,
"logprob": -0.21618652,
"special": false,
"text": "name"
},
@ -173,7 +173,7 @@
},
{
"id": 11571,
"logprob": -0.10021973,
"logprob": -0.08892822,
"special": false,
"text": "!\""
},

View File

@ -30,19 +30,19 @@
},
{
"id": 264,
"logprob": -0.37573242,
"logprob": -0.38061523,
"special": false,
"text": " a"
},
{
"id": 633,
"logprob": -0.09161377,
"logprob": -0.09301758,
"special": false,
"text": " new"
},
{
"id": 4480,
"logprob": -0.26171875,
"logprob": -0.26782227,
"special": false,
"text": " feature"
},
@ -78,7 +78,7 @@
},
{
"id": 13,
"logprob": 0.0,
"logprob": -0.10632324,
"special": false,
"text": "\n"
}

View File

@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
print(repr(response.choices[0].message.content))
assert (
response.choices[0].message.content
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas"
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C"
)
assert response == response_snapshot

View File

@ -11,7 +11,7 @@ from text_generation.types import (
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
) as handle:
yield handle
@ -34,16 +34,19 @@ def test_flash_llama_completion_single_prompt(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": "Say this is a test",
"max_tokens": 5,
"seed": 0,
"prompt": "What is Deep Learning?",
"max_tokens": 10,
"temperature": 0.0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 1
assert (
response["choices"][0]["text"]
== " A Beginners Guide\nDeep learning is a subset"
)
assert response == response_snapshot
@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": ["Say", "this", "is", "a"],
"prompt": [
"What is Deep Learning?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"temperature": 0.0,
},
headers=flash_llama_completion.headers,
stream=False,
@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
response = response.json()
assert len(response["choices"]) == 4
all_indexes = [choice["index"] for choice in response["choices"]]
all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
all_indexes.sort()
assert all_indexes == [0, 1, 2, 3]
all_indices, all_strings = zip(*all_indexes)
assert list(all_indices) == [0, 1, 2, 3]
assert list(all_strings) == [
" A Beginners Guide\nDeep learning is a subset",
" This is a question that has puzzled many people for",
" Paris\nWhat is the capital of France?\nThe",
'usculas_minusculas(s):\n """\n',
]
assert response == response_snapshot
@ -77,19 +93,21 @@ async def test_flash_llama_completion_many_prompts_stream(
request = {
"model": "tgi",
"prompt": [
"What color is the sky?",
"What is Deep Learning?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"temperature": 0.0,
"stream": True,
}
url = f"{flash_llama_completion.base_url}/v1/completions"
chunks = []
strings = [""] * 4
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
@ -108,7 +126,15 @@ async def test_flash_llama_completion_many_prompts_stream(
for c in chunk:
chunks.append(Completion(**c))
assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4
index = c["choices"][0]["index"]
assert 0 <= index <= 4
strings[index] += c["choices"][0]["text"]
assert response.status == 200
assert list(strings) == [
" A Beginners Guide\nDeep learning is a subset",
" This is a question that has puzzled many people for",
" Paris\nWhat is the capital of France?\nThe",
'usculas_minusculas(s):\n """\n',
]
assert chunks == response_snapshot

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,75 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_handle(launcher):
with launcher("mistralai/Mixtral-8x7B-v0.1", num_shard=8) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mixtral(flash_mixtral_handle):
await flash_mixtral_handle.health(300)
return flash_mixtral_handle.client
@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral(flash_mixtral, response_snapshot):
response = await flash_mixtral.generate(
"What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is an optimization algorithm used to minimize"
)
assert response == response_snapshot
@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral_all_params(flash_mixtral, response_snapshot):
response = await flash_mixtral.generate(
"What is gradient descent?\n\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nIt seems to me, that if you're"
)
assert response == response_snapshot
@pytest.mark.skip(reason="requires > 4 shards")
@pytest.mark.asyncio
async def test_flash_mixtral_load(flash_mixtral, generate_load, response_snapshot):
responses = await generate_load(
flash_mixtral, "What is gradient descent?\n\n", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is an optimization algorithm used to minimize"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -36,6 +36,7 @@ tools = [
},
},
"required": ["location", "format"],
"additionalProperties": False,
},
},
},
@ -62,13 +63,13 @@ tools = [
},
},
"required": ["location", "format", "num_days"],
"additionalProperties": False,
},
},
},
]
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
max_tokens=100,
seed=1,
tools=tools,
presence_penalty=-1.1,
temperature=0.0,
messages=[
{
"role": "system",
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="auto",
presence_penalty=-1.1,
messages=[
{
"role": "system",
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1
assert count == 38
assert count == 48
assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=8,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
)
assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]
assert (
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
)
assert responses == response_snapshot

File diff suppressed because it is too large

View File

@ -6,9 +6,10 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
[tool.poetry.dependencies]
pydantic = "> 2, < 3"
python = ">=3.9,<3.13"
syrupy = "4.0.1"
python = ">=3.10,<3.13"
syrupy = "^4.7.1"
text-generation = "^0.6.0"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^6.1.3"
docker = "^7"
numpy = "^1.20"

View File

@ -1,35 +1,35 @@
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
colored==1.4.4 ; python_version >= "3.9" and python_version < "3.13"
docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13"
exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13"
pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13"
text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
websocket-client==1.6.2 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.0 ; python_version >= "3.10" and python_version < "3.13"
aiohttp==3.10.5 ; python_version >= "3.10" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
docker==7.1.0 ; python_version >= "3.10" and python_version < "3.13"
exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
filelock==3.16.0 ; python_version >= "3.10" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "3.13"
fsspec==2024.9.0 ; python_version >= "3.10" and python_version < "3.13"
huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "3.13"
multidict==6.1.0 ; python_version >= "3.10" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.10" and python_version < "3.13"
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "3.13"
pydantic-core==2.23.3 ; python_version >= "3.10" and python_version < "3.13"
pydantic==2.9.1 ; python_version >= "3.10" and python_version < "3.13"
pytest-asyncio==0.21.2 ; python_version >= "3.10" and python_version < "3.13"
pytest==7.4.4 ; python_version >= "3.10" and python_version < "3.13"
pywin32==306 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13"
syrupy==4.7.1 ; python_version >= "3.10" and python_version < "3.13"
text-generation==0.6.1 ; python_version >= "3.10" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
tqdm==4.66.5 ; python_version >= "3.10" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
yarl==1.11.1 ; python_version >= "3.10" and python_version < "3.13"

View File

@ -8,7 +8,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines};
use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
@ -18,12 +18,103 @@ use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use std::{
fs, io,
io::{Read, Write},
};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
fn get_config(
model_id: &str,
revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
let mut path = std::path::Path::new(model_id).to_path_buf();
let model_id = model_id.to_string();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
}
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some("flashdecoding".to_string());
}
}
Some("t5") => {}
_ => {}
}
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
}
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
(prefix_caching, attention)
}
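The function above resolves two runtime knobs in a fixed order: an explicit USE_PREFIX_CACHING or ATTENTION environment variable always wins; otherwise the model config (vision model, encoder-decoder, head_dim, model_type such as gemma2/falcon/deepseek_v2) picks a default, and finally prefix caching falls back to "true" and attention to "flashinfer". A minimal standalone sketch of that precedence pattern follows; the function name and values are illustrative, not TGI's API:

// Standalone sketch (assumed example): an explicit env var wins, otherwise a
// config-driven default is chosen based on whether flashinfer supports the head dim.
fn pick_attention(head_dim: Option<usize>) -> String {
    match std::env::var("ATTENTION") {
        Ok(v) => v,                       // 1. explicit user choice always wins
        Err(_) => match head_dim {        // 2. otherwise derive from the config
            Some(64) | Some(128) | Some(256) => "flashinfer".to_string(),
            _ => "flashdecoding".to_string(),
        },
    }
}

fn main() {
    std::env::set_var("ATTENTION", "paged");
    assert_eq!(pick_attention(Some(128)), "paged");          // env var overrides config
    std::env::remove_var("ATTENTION");
    assert_eq!(pick_attention(Some(96)), "flashdecoding");   // unsupported head dim
}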
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
@ -31,6 +122,12 @@ struct RawConfig {
model_type: Option<String>,
max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
}
#[derive(Deserialize)]
@ -38,10 +135,17 @@ struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
}
impl From<RawConfig> for Config {
@ -51,9 +155,32 @@ impl From<RawConfig> for Config {
.or(other.max_seq_len)
.or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
(Some(hidden_size), _, Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
// Legacy
(_, Some(hidden_size), Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
_ => None,
}
});
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
}
}
}
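The head_dim fallback above encodes the usual transformer relationship head_dim = hidden_size / num_attention_heads, with n_embd accepted as the legacy field name. A small worked sketch with made-up model shapes:

// Hypothetical worked example of the derivation above (values are illustrative,
// not read from any real config.json): derive head_dim only when divisible.
fn derive_head_dim(hidden_size: usize, num_attention_heads: usize) -> Option<usize> {
    (hidden_size % num_attention_heads == 0).then(|| hidden_size / num_attention_heads)
}

fn main() {
    assert_eq!(derive_head_dim(4096, 32), Some(128)); // flashinfer-friendly head dim
    assert_eq!(derive_head_dim(4544, 71), Some(64));  // Falcon-7B-like shape
    assert_eq!(derive_head_dim(4100, 32), None);      // not divisible: nothing derived
}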
@ -731,6 +858,7 @@ fn shard_manager(
.args(shard_args)
.env_clear()
.envs(envs)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
@ -752,12 +880,13 @@ fn shard_manager(
};
// Redirect STDOUT to the console
let mut pstdin = p.stdin.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread
thread::spawn(move || {
log_lines(shard_stdout_reader.lines());
log_lines(shard_stdout_reader);
});
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
@ -766,6 +895,18 @@ fn shard_manager(
err_sender.send(line).unwrap_or(());
}
});
// We read stdin in another thread as it seems that lines() can block in some cases
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
let mut buffer = vec![0; 4096];
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
}
}
}
});
let mut ready = false;
let start_time = Instant::now();
@ -872,19 +1013,36 @@ impl PythonLogMessage {
}
}
impl TryFrom<&String> for PythonLogMessage {
impl TryFrom<&[u8]> for PythonLogMessage {
type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value)
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice::<Self>(value)
}
}
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
for line in lines.map_while(Result::ok) {
match PythonLogMessage::try_from(&line) {
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
let mut buffer = vec![0u8; 8 * 4096];
let mut stdout = std::io::stdout();
loop {
let n = bufread.read(&mut buffer);
if let Ok(n) = n {
if n > 0 {
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
while let Some(line) = lines.next() {
match PythonLogMessage::try_from(line) {
Ok(log) => log.trace(),
Err(_) => tracing::debug!("{line}"),
// For interactive debugging ?
Err(_) => {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
}
stdout.flush().unwrap();
}
}
}
}
}
}
}
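log_lines now consumes the raw byte stream instead of line-buffered strings: it reads a chunk, splits on newlines, tries to decode each piece as a structured Python log record, and echoes anything that is not JSON straight to stdout. A minimal sketch of that strategy, assuming a serde_json dependency and made-up buffer contents:

// Minimal sketch (assumed example, not TGI code): split raw bytes on '\n',
// attempt JSON decoding per chunk, and fall back to printing the raw text.
fn main() {
    let buffer: &[u8] = b"{\"text\":\"loading weights\"}\nplain stdout line";
    for line in buffer.split(|b| *b == b'\n') {
        match serde_json::from_slice::<serde_json::Value>(line) {
            Ok(record) => println!("structured: {record}"),
            Err(_) => println!("raw: {}", String::from_utf8_lossy(line)),
        }
    }
}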
@ -1044,7 +1202,7 @@ fn download_convert_model(
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || {
log_lines(download_stdout.lines());
log_lines(download_stdout);
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
@ -1439,47 +1597,12 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args);
let get_max_positions_quantize =
|| -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("ATTENTION", "flashdecoding");
}
let config: Config = config.into();
let quantize = config.quantize;
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = if let Some(config) = &config {
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
@ -1489,18 +1612,20 @@ fn main() -> Result<(), LauncherError> {
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
Ok((max_default, quantize))
max_default
} else {
Ok((max_position_embeddings, quantize))
max_position_embeddings
}
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
max_default
}
} else {
max_default
};
let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
get_max_positions_quantize().unwrap_or((4096, None));
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
@ -1718,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
shutdown.clone(),
&shutdown_receiver,
)
.map_err(|err| {
.inspect_err(|_| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;
// Default exit code

View File

@ -33,13 +33,13 @@ export function get_options() {
// rate: 20,
// timeUnit: '1s',
// },
load_test: {
executor: 'constant-arrival-rate',
duration: '60s',
preAllocatedVUs: 100,
rate: 1,
timeUnit: '1s',
},
// load_test: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 100,
// rate: 1,
// timeUnit: '1s',
// },
// breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 300,
@ -47,12 +47,12 @@ export function get_options() {
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ],
// },
// throughput: {
// executor: 'shared-iterations',
// vus: 100,
// iterations: 200,
// maxDuration: '40s',
// },
throughput: {
executor: 'shared-iterations',
vus: 100,
iterations: 200,
maxDuration: '40s',
},
},
};
}

View File

@ -28,6 +28,9 @@ defaultCrateOverrides
];
};
};
pyo3-build-config = attrs: {
buildInputs = [ python3 ];
};
text-generation-benchmark = attrs: {
src = filter {
root = ../benchmark;

View File

@ -26,8 +26,10 @@
opentelemetry-instrumentation-grpc,
opentelemetry-semantic-conventions,
peft,
punica-kernels,
safetensors,
tokenizers,
torch,
sentencepiece,
transformers,
typer,
@ -91,6 +93,7 @@ buildPythonPackage {
opentelemetry-instrumentation-grpc
opentelemetry-semantic-conventions
peft
punica-kernels
safetensors
sentencepiece
tokenizers

View File

@ -137,6 +137,8 @@ message Request {
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Whether to add the tokenizer's special tokens to the input
bool add_special_tokens = 13;
}
message Batch {

View File

@ -46,8 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] }
csv = "1.3.0"
ureq = "=2.9"
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[build-dependencies]

View File

@ -1,11 +1,8 @@
use std::collections::HashSet;
use crate::infer::InferError;
use crate::{
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
};
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use std::collections::HashSet;
/// Raise a exception (custom function) used in the chat templates
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
@ -32,6 +29,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
tracing::debug!("Loading template: {:#?}", template_str);
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
@ -42,6 +40,7 @@ impl ChatTemplate {
let variables = template.undeclared_variables(true);
// check if the `tools` variable is used in the template
let use_default_tool_template = !variables.contains("tools");
tracing::debug!("Use default tool template: {}", use_default_tool_template);
Self {
template,
@ -56,25 +55,36 @@ impl ChatTemplate {
&self,
guideline: Option<&str>,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
});
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
// check if guideline is expected but not provided
if self.variables.contains("guideline") && guideline.is_none() {
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
}
let tools = match tools_and_prompt {
Some((tools, tool_prompt)) => {
// check if the `tools` variable is used in the template
// if not, we need to append the tools to the last message
let text = if self.use_default_tool_template {
match serde_json::to_string(&tools) {
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
Err(e) => return Err(InferError::ToolError(e.to_string())),
}
} else {
// if the `tools` variable is used in the template, we just append the tool_prompt
format!("\n---\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
last_message.content.push(MessageChunk::Text { text });
}
Some(tools)
}
None => None,
};
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
guideline,
@ -82,8 +92,7 @@ impl ChatTemplate {
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
tools,
})
.map_err(InferError::TemplateError)
}
@ -95,7 +104,7 @@ mod tests {
use crate::infer::chat_template::raise_exception;
use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
};
use minijinja::Environment;
@ -854,11 +863,46 @@ mod tests {
content: MessageContent::SingleText("Just testing".to_string()),
},
];
let tools = serde_json::json!("[]");
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
let result = ct.apply(None, msgs, Some(grammer_with_prompt));
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
#[test]
fn test_chat_template_with_custom_tool_template() {
// chat template from meta-llama/Meta-Llama-3.1-8B-Instruct
let ct = ChatTemplate::new(
"{{- bos_token }}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "system".to_string(),
content: MessageContent::SingleText(
"Youre a helpful assistant! Answer the users question best you can."
.to_string(),
),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"What is the weather like in Brooklyn, New York?".to_string(),
),
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}

View File

@ -3,7 +3,7 @@ mod chat_template;
pub mod tool_grammar;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::GrammarType;
use crate::Tool;
use crate::{
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
Message, PrefillToken, Token,
@ -120,10 +120,11 @@ impl Infer {
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let add_special_tokens = request.add_special_tokens;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.tokenize(inputs, add_special_tokens, truncate)
.await
.map_err(|err| {
tracing::error!("Tokenization {err}");
@ -140,12 +141,12 @@ impl Infer {
&self,
guideline: Option<String>,
messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(guideline.as_deref(), messages, grammar_with_prompt)
.apply(guideline.as_deref(), messages, tools_and_prompt)
.map_err(|e| {
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}");
@ -335,6 +336,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Incomplete generation stream")]
IncompleteGenerationStream,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Missing template vatiable: {0}")]
@ -350,6 +353,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::IncompleteGenerationStream => "incomplete_generation_stream",
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",

View File

@ -1,5 +1,8 @@
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
@ -16,17 +19,38 @@ impl ToolGrammar {
}
pub fn apply(
tools: Option<Vec<Tool>>,
tools: Vec<Tool>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, return them unchanged with no schema
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
if tools.is_empty() {
return Ok((tools, None));
}
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"type": "string",
"description": "The error or issue to notify"
}
},
"required": ["error"]
}),
},
};
tools.push(notify_error);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
@ -35,87 +59,57 @@ impl ToolGrammar {
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
ToolType::OneOf => tools.clone(),
ToolType::NoTool => return Ok((tools, None)),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
let mut params = Map::new();
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
Value::String(func.description.unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
let mut properties = Map::new();
let mut required = vec![Value::String("_name".to_string())];
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
if let Value::Object(args) = func.arguments {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}
if let Some(Value::Array(reqs)) = args.get("required") {
required.extend(reqs.clone());
}
params.insert(
"additionalProperties".to_string(),
Value::Bool(
args.get("additionalProperties").and_then(|v| v.as_str())
== Some("true"),
),
);
}
params.insert("properties".to_string(), Value::Object(properties));
params.insert("required".to_string(), Value::Array(required));
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
let tool_schema = JsonSchemaTool {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
@ -123,13 +117,10 @@ impl ToolGrammar {
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
Ok((tools, Some(tool_schema)))
}
}
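ToolGrammar::apply now returns both the (possibly extended) tool list and an optional JsonSchemaTool used as the JSON grammar. For one user tool plus the injected notify_error fallback, the generated schema looks roughly like the sketch below; the exact field names and serde layout are assumptions for illustration, built here with serde_json::json!:

// Rough, illustrative sketch (assumed shape, not the exact serde output) of the
// grammar built above: each function gets a "_name" const plus its own properties,
// and the top-level properties reference the functions by "$ref".
fn main() {
    let schema = serde_json::json!({
        "$functions": {
            "get_current_weather": {
                "type": "object",
                "description": "Get the current weather",
                "properties": {
                    "_name": { "type": "string", "const": "get_current_weather" },
                    "location": { "type": "string" },
                    "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
                },
                "required": ["_name", "location", "format"],
                "additionalProperties": false
            },
            "notify_error": {
                "type": "object",
                "properties": {
                    "_name": { "type": "string", "const": "notify_error" },
                    "error": { "type": "string" }
                },
                "required": ["_name", "error"]
            }
        },
        "properties": {
            "function": [
                { "$ref": "#/$functions/get_current_weather" },
                { "$ref": "#/$functions/notify_error" }
            ]
        }
    });
    println!("{}", serde_json::to_string_pretty(&schema).unwrap());
}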

View File

@ -22,6 +22,16 @@ pub enum Attention {
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
@ -45,13 +55,20 @@ impl std::str::FromStr for Attention {
}
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance {
pub(crate) struct GenerateVertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(ChatRequest),
}
#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]
@ -840,10 +857,10 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default = "default_tool_prompt")]
#[serde(default)]
#[schema(
nullable = true,
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
)]
pub tool_prompt: Option<String>,
@ -865,10 +882,8 @@ pub(crate) struct ChatRequest {
pub guideline: Option<String>,
}
fn default_tool_prompt() -> Option<String> {
Some(
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
)
pub fn default_tool_prompt() -> String {
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
}
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@ -910,7 +925,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
}
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools {
pub struct JsonSchemaTool {
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
@ -968,8 +983,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
add_generation_prompt: bool,
tools: Option<&'a str>,
tools_prompt: Option<&'a str>,
tools: Option<Vec<Tool>>,
guideline: Option<&'a str>,
}
@ -1075,6 +1089,16 @@ pub(crate) struct GenerateRequest {
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
/// This is used internally because some requests
/// already contain the templated input therefore
/// we shouldn't add the special tokens.
#[serde(default = "default_true", skip)]
pub add_special_tokens: bool,
}
fn default_true() -> bool {
true
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1092,6 +1116,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
add_special_tokens: true,
parameters: req.parameters,
}
}
@ -1243,6 +1268,34 @@ pub(crate) struct ErrorResponse {
pub error_type: String,
}
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelInfo {
#[schema(example = "gpt2")]
pub id: String,
#[schema(example = "model")]
pub object: String,
#[schema(example = 1686935002)]
pub created: u64,
#[schema(example = "openai")]
pub owned_by: String,
}
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelsInfo {
#[schema(example = "list")]
pub object: String,
pub data: Vec<ModelInfo>,
}
impl Default for ModelsInfo {
fn default() -> Self {
ModelsInfo {
object: "list".to_string(),
data: Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -8,7 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
use crate::ChatTokenizeResponse;
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -23,7 +23,8 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -40,6 +41,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -47,7 +49,6 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
@ -116,6 +117,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0)
}
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v1/models",
responses(
(status = 200, description = "Served model info", body = ModelInfo),
(status = 404, description = "Model not found", body = ErrorResponse),
)
)]
#[instrument(skip(info))]
/// Get model info
async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
Json(ModelsInfo {
data: vec![ModelInfo {
id: info.0.model_id.clone(),
object: "model".to_string(),
created: 0, // TODO: determine how to get this
owned_by: info.0.model_id.clone(),
}],
..Default::default()
})
}
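The new /v1/models route mirrors the OpenAI models-list endpoint and reports the single served model. A hypothetical client-side check, assuming a local instance listening on port 3000 and reqwest/tokio available (neither is implied by the patch itself):

// Hypothetical usage sketch: fetch the model list from a locally running server.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = reqwest::get("http://localhost:3000/v1/models")
        .await?
        .text()
        .await?;
    // Expected shape: {"object":"list","data":[{"id":"<model id>","object":"model",...}]}
    println!("{body}");
    Ok(())
}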
#[utoipa::path(
post,
tag = "Text Generation Inference",
@ -146,7 +170,7 @@ async fn get_chat_tokenize(
} = req;
let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
let (inputs, _grammar, _using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
@ -158,6 +182,7 @@ async fn get_chat_tokenize(
let generate_request = GenerateRequest {
inputs,
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
@ -293,7 +318,10 @@ pub(crate) async fn generate_internal(
metrics::counter!("tgi_request_count").increment(1);
// Do not log ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
tracing::debug!(
"Input: {}",
&req.inputs.chars().take(1000).collect::<String>()
);
let compute_characters = req.inputs.chars().count();
let mut add_prompt = None;
@ -649,7 +677,7 @@ async fn generate_stream_internal(
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
let err = InferError::IncompleteGenerationStream;
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
@ -754,6 +782,7 @@ async fn completions(
.iter()
.map(|prompt| GenerateRequest {
inputs: prompt.to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
best_of: None,
temperature,
@ -1158,14 +1187,16 @@ async fn chat_completions(
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, tool_grammar) = prepare_chat_input(
let (inputs, grammar, using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
@ -1178,6 +1209,7 @@ async fn chat_completions(
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
@ -1221,7 +1253,7 @@ async fn chat_completions(
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if tool_grammar.is_some() {
let (content, tool_calls) = if using_tools {
(None, Some(vec![stream_token.token.text]))
} else {
let content = if !stream_token.token.special {
@ -1275,7 +1307,7 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let (tool_calls, output) = if tool_grammar.is_some() {
let (tool_calls, output) = if using_tools {
let gen_text_value: Value =
serde_json::from_str(&generation.generated_text).map_err(|e| {
InferError::ToolError(format!(
@ -1377,13 +1409,14 @@ async fn vertex_compatibility(
));
}
// Process all instances
let predictions = req
.instances
.iter()
.map(|instance| {
let generate_request = GenerateRequest {
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances.iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
@ -1392,14 +1425,98 @@ async fn vertex_compatibility(
decoder_input_details: true,
..Default::default()
},
},
VertexInstance::Chat(instance) => {
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
presence_penalty,
frequency_penalty,
top_p,
top_logprobs,
..
} = instance.clone();
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, _using_tools) = match prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
) {
Ok(result) => result,
Err(e) => {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Failed to prepare chat input: {}", e),
error_type: "Input preparation error".to_string(),
}),
));
}
};
async {
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty,
top_k: None,
top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
}
}
};
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer.clone()),
compute_type.clone(),
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span.clone(),
span_clone,
)
.await
.map(|(_, Json(generation))| generation.generated_text)
@ -1412,11 +1529,13 @@ async fn vertex_compatibility(
}),
)
})
});
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
// execute all futures in parallel, collect results, returning early if any error occurs
let results = futures::future::join_all(futures).await;
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
let predictions = predictions?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())
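The Vertex handler now builds one GenerateRequest per instance, pushes the corresponding futures into a Vec, awaits them with join_all, and folds the results so that any single failure becomes an error for the whole batch. A standalone sketch of that fan-out/collect pattern, assuming the futures and tokio crates; names and values are illustrative:

// Assumed example, not TGI code: run every request future, then fold the results
// so the first Err(_) turns the whole batch into an Err.
#[tokio::main]
async fn main() {
    // map keeps a single concrete future type, so they can live in one Vec
    let futures: Vec<_> = ["prediction a", "prediction b"]
        .into_iter()
        .map(|p| async move { Ok::<String, String>(p.to_string()) })
        .collect();
    let results = futures::future::join_all(futures).await;
    let predictions: Result<Vec<String>, String> = results.into_iter().collect();
    assert_eq!(predictions.unwrap(), vec!["prediction a", "prediction b"]);
}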
@ -1499,6 +1618,7 @@ chat_completions,
completions,
tokenize,
metrics,
openai_get_model_info,
),
components(
schemas(
@ -1551,6 +1671,7 @@ ToolCall,
Function,
FunctionDefinition,
ToolChoice,
ModelInfo,
)
),
tags(
@ -1739,18 +1860,34 @@ pub async fn run(
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
use pyo3::prelude::*;
let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
"revision",
revision.clone().unwrap_or_else(|| "main".to_string()),
)]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
});
let filename = if convert.is_ok() {
// If we have correctly loaded and resaved with transformers
// We might have modified the tokenizer.json according to transformers
"out/tokenizer.json".into()
} else {
filename
};
Tokenizer::from_file(filename).ok()
});
let config: Option<Config> = config_filename.and_then(|filename| {
@ -2244,7 +2381,8 @@ async fn start(
.route("/info", get(get_model_info))
.route("/health", get(health))
.route("/ping", get(health))
.route("/metrics", get(metrics));
.route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled {
@ -2436,6 +2574,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
@ -2468,78 +2607,7 @@ pub enum WebServerError {
Axum(#[from] axum::BoxError),
}
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input(
infer: &Infer,
@ -2556,19 +2624,139 @@ fn prepare_chat_input(
));
}
// when response_format is set, tools are not included when applying the chat template to generate inputs
if let Some(format) = response_format {
let inputs = infer.apply_chat_template(guideline, messages, None)?;
return Ok((inputs, Some(format), None));
return Ok((inputs, Some(format), false));
}
// if tools are set, apply the tool grammar and then the chat template
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
let grammar = tool_grammar
// when no response_format is set and tools are included, apply the chat template with the tools
// to generate inputs
if let Some(tools) = tools {
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
let grammar = tool_schema
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));
let tools_grammar_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
Ok((inputs, grammar, tool_grammar))
let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt.into())),
)?;
return Ok((inputs, grammar, tool_schema.is_some()));
}
// if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, messages, None)?;
Ok((inputs, None, false))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ChatTemplateVersions;
use crate::HubTokenizerConfig;
use crate::TokenizerConfigToken;
use crate::Tool;
use serde_json::json;
#[test]
fn test_prepare_chat_input() {
// Mock Backend to avoid network requests
struct MockBackend;
impl Backend for MockBackend {
fn schedule(
&self,
_request: crate::validation::ValidGenerateRequest,
) -> Result<
tokio_stream::wrappers::UnboundedReceiverStream<
Result<InferStreamResponse, InferError>,
>,
InferError,
> {
unimplemented!("Never called in this test");
}
fn health<'a, 'async_trait>(
&'a self,
_current_health: bool,
) -> core::pin::Pin<
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
>
where
'a: 'async_trait,
Self: 'async_trait,
{
unimplemented!("Never called in this test");
}
}
let backend = MockBackend {};
let mut tokenizer_config = HubTokenizerConfig::default();
// mock tokenizer config values
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
tokenizer_config.chat_template = Some(
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
);
let infer = Infer::new(
backend,
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
1,
tokenizer_config,
HubProcessorConfig::default(),
);
let response_format = None;
let tools = Some(vec![Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "get_current_weather".to_string(),
description: Some("Get the current weather".to_string()),
arguments: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": ["location", "format"]
}),
},
}]);
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
let guideline = None;
let messages = vec![Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"What is the weather like in New York?".to_string(),
),
}];
let result = prepare_chat_input(
&infer,
response_format,
tools,
ToolChoice(None),
tool_prompt,
guideline,
messages,
);
assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap();
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}

View File

@ -95,6 +95,7 @@ impl Validation {
pub async fn tokenize(
&self,
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer
@ -104,7 +105,11 @@ impl Validation {
// Send request to the background validation task
// Unwrap is safe here
sender
.send(((inputs, truncate), response_sender, Span::current()))
.send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap();
// Await on response channel
@ -121,11 +126,15 @@ impl Validation {
async fn validate_input(
&self,
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
if let Some((encoding, inputs)) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel
let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate)
@ -158,7 +167,8 @@ impl Validation {
));
}
let input_ids = encoding.get_ids()[..input_length].to_owned();
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
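// Illustrative example (assumed values): with ids.len() == 10 and input_length == 6,
// the previous `ids[..6]` kept the first six tokens, whereas
// `ids[ids.len().saturating_sub(6)..]` keeps `ids[4..10]`, so truncation now drops the
// oldest tokens and preserves the most recent part of the prompt.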
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
@ -324,7 +334,12 @@ impl Validation {
// Validate inputs
let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.validate_input(
request.inputs,
request.add_special_tokens,
truncate,
max_new_tokens,
)
.await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@ -401,6 +416,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
input_ids: input_ids.map(Arc::new),
add_special_tokens: request.add_special_tokens,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -449,12 +465,15 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| {
response_tx
.send(prepare_input(
inputs,
truncate,
add_special_tokens,
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
@ -591,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
fn prepare_input(
inputs: String,
_truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
@ -628,14 +648,14 @@ fn prepare_input(
// Get the number of tokens in the input
let encoding = tokenizer
.encode(tokenizer_query, true)
.encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, input_chunks))
}
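// Sketch of the new flag (assumed tokenizer behaviour): if the post-processor prepends
// "<s>", encoding "Hello" with add_special_tokens = true yields ["<s>", "Hello"], while
// a caller whose prompt was rendered by a chat template that already starts with the BOS
// token can pass add_special_tokens = false to avoid a duplicated "<s>" at the front.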
type TokenizerRequest = (
(String, Option<usize>),
(String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span,
);
@ -720,6 +740,7 @@ pub struct ValidGenerateRequest {
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub add_special_tokens: bool,
pub decoder_input_details: bool,
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
@ -826,7 +847,7 @@ mod tests {
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
@ -861,7 +882,7 @@ mod tests {
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -895,6 +916,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
best_of: Some(2),
do_sample: false,
@ -934,6 +956,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: Some(1.0),
max_new_tokens: Some(5),
@ -949,6 +972,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: Some(0.99),
max_new_tokens: Some(5),
@ -964,6 +988,7 @@ mod tests {
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: None,
max_new_tokens: Some(5),
@ -1002,6 +1027,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(5),
max_new_tokens: Some(5),
@ -1017,6 +1043,7 @@ mod tests {
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: Some(5),
@ -1029,6 +1056,7 @@ mod tests {
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: Some(5),
@ -1041,6 +1069,7 @@ mod tests {
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: Some(5),
@ -1089,6 +1118,7 @@ mod tests {
let chunks = match validation
.tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
true,
None,
)
.await
@ -1148,6 +1178,7 @@ mod tests {
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF
),
true,
None,
)
.await

View File

@ -1,5 +1,5 @@
[toolchain]
# Released on: July 25, 2024
# https://releases.rs/docs/1.80.0/
channel = "1.79.0"
channel = "1.80.0"
components = ["rustfmt", "clippy"]

View File

@ -7,6 +7,7 @@ include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer
unit-tests:
pytest -s -vv -m "not private" tests

View File

@ -1,7 +1,9 @@
fbgemm_commit := v0.8.0
build-fbgemm:
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
@if [ ! -d "fbgemm" ]; then \
git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
fi
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
git submodule update --init --recursive && \
cd fbgemm_gpu && \

View File

@ -0,0 +1,2 @@
install-flashinfer:
pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4

5 server/poetry.lock generated
View File

@ -3237,11 +3237,6 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
]
[package.dependencies]

View File

@ -1,7 +1,10 @@
import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture
def default_pb_parameters():

View File

@ -267,7 +267,7 @@ def test_batch_concatenate(
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -262,7 +262,7 @@ def test_batch_concatenate(
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -281,7 +281,7 @@ def test_batch_concatenate(
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -1,6 +1,54 @@
import pytest
from unittest.mock import Mock
from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
from text_generation_server.utils.adapter import (
get_attn_weights,
get_mlp_weights,
parse_lora_adapters,
AdapterInfo,
)
def test_parse_lora_adapters_empty():
assert parse_lora_adapters(None) == []
assert parse_lora_adapters("") == []
def test_parse_lora_adapters_single():
result = parse_lora_adapters("adapter1")
assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
def test_parse_lora_adapters_with_path():
result = parse_lora_adapters("adapter1=path/to/adapter1")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
]
def test_parse_lora_adapters_with_path_and_revision():
result = parse_lora_adapters("adapter1=path/to/adapter1@main")
assert result == [
AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
]
def test_parse_lora_adapters_multiple():
result = parse_lora_adapters(
"adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
)
assert result == [
AdapterInfo(id="adapter1", path=None, revision=None),
AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
]
def test_parse_lora_adapters_invalid_format():
try:
parse_lora_adapters("adapter1,invalid=format=test,adapter3")
assert False, "Should have raised ValueError"
except ValueError as e:
assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
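# These formats mirror the adapter string accepted at launch time, e.g. (assumed flag
# spelling) `--lora-adapters "adapter1,adapter2=/data/adapter2@main"` or the equivalent
# LORA_ADAPTERS environment variable.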
def test_get_attn_weights():
@ -22,6 +70,10 @@ def test_get_attn_weights():
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,
@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
"model.layers.2.self_attn.k_proj",
mock_layer.self_attn.query_key_value,
),
(2, "qkv_proj"): (
"model.layers.2.self_attn.qkv_proj",
mock_layer.self_attn.query_key_value,
),
(2, "v_proj"): (
"model.layers.2.self_attn.v_proj",
mock_layer.self_attn.query_key_value,

View File

@ -9,26 +9,46 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(self, input_lengths):
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
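# Worked example (assumed values): for two sequences with input_lengths = [3, 5] and
# prefix_lengths = [0, 2], calling Seqlen(input_lengths, prefix_lengths) with no
# cu_seqlen_q (the decode case) gives cu_seqlen_q = [0, 1, 2], max_q = 1, and
# cu_seqlen_k = [0, 3, 10], since each sequence contributes input_length + prefix_length
# keys; sequence i then spans cu_seqlen_k[i]:cu_seqlen_k[i+1] in the flattened KV data.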
def clamp(self, max):
# Flash decoding doesn't need to clamp
@ -39,6 +59,11 @@ else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -222,18 +222,15 @@ if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
assert window_size_left == -1, "Windowing is not supported with flash infer"
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
@ -244,18 +241,17 @@ if ATTENTION == "flashinfer":
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)
elif V2:
def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
@ -266,17 +262,17 @@ elif V2:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
None,
max_s,
max_s,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,

View File

@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False
def attention(
q,
k,
v,
cu_seqlens,
max_s,
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
@ -22,14 +22,14 @@ def attention(
# We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q,
k,
v,
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_q,
0.0,
softmax_scale,
False,

View File

@ -45,12 +45,107 @@ class MLPSpeculatorLayerNorm(nn.Module):
return x
INV_SQRT2 = 2**-0.5
def simple_norm(x: torch.Tensor, eps=1e-06):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
x = xf.type_as(x)
return x * INV_SQRT2
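# Quick numeric check (assumed input): for x = [3.0, 4.0], mean(x**2) = 12.5, so
# simple_norm(x) = [3, 4] / sqrt(12.5) * INV_SQRT2 ≈ [0.6, 0.8]: an RMS-normalised
# vector additionally scaled by 1/sqrt(2).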
class MLPSpeculatorModelTied(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
self.proj0 = FastLinear.load(
config,
prefix=f"{prefix}.proj.0",
weights=weights,
bias=False,
)
self.proj1 = FastLinear.load(
config,
prefix=f"{prefix}.proj.1",
weights=weights,
bias=False,
)
self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
self.ln = MLPSpeculatorLayerNorm(
prefix=f"{prefix}.ln.0",
config=config,
weights=weights,
)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
self.activation = nn.GELU()
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
self.inner_dim / 2
)
self.emb.weight *= self.emb_weight
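# Rough sanity check of the weighting above (assumed n_predict = 4):
# state_weight = 0.5 ** (0.5 / 4) ≈ 0.917, so by the final head the original hidden
# state has been scaled by 0.917 ** 4 = 0.5 ** 0.5 ≈ 0.707, i.e. it still carries
# about 50% of the state variance, matching the comment above; emb_weight then scales
# the tied embedding so the token contribution stays complementary, with the
# sqrt(inner_dim / 2) factor presumably matching the embedding's initialisation scale.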
def forward(
self,
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
):
top_k_tokens_per_head = self.top_k_tokens_per_head
# k indicates # of candidates
# h indicates # of generated tokens
state = hidden_states
b = state.size(0)
ind = input_ids.unsqueeze(0)
all_probs = torch.empty(
b, self.n_predict, self.vsize, device=state.device
) # b k h v
assert (
len(top_k_tokens_per_head) == self.n_predict
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
for i in range(self.n_predict):
# Project and predict
z = self.emb(ind)
# z = z.mul(self.emb_weight) # b k d
if i == 0:
state = self.proj0(state) * self.state_weight + z
else:
state = self.proj1(state) * self.state_weight + z
state = self.activation(self.ln(state)) # b k d
probs = F.log_softmax(self.head(state), dim=-1) # b k v
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
# Update candidate set with new predictions
# Update distribution set with new logits
all_probs[:, i] = probs.exp()
# Update state, log_probs and ind for new predictions
state = state.unsqueeze(2).expand(
-1, -1, top_k_tokens_per_head[i], -1
) # b k k' d
state = state.reshape(-1, b, state.size(3)) # b kk' d
ind = preds.view(-1, b) # b kk'
speculative_logits = all_probs
return speculative_logits
class MLPSpeculatorModel(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = nn.ModuleList(
[
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
@ -84,13 +179,15 @@ class MLPSpeculatorModel(torch.nn.Module):
)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict)
self.emb_weight = math.sqrt(1 - self.state_weight**2)
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
self.activation = nn.GELU()
# TODO
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
self.inner_dim / 2
)
self.emb.weight *= self.emb_weight
def forward(
self,
@ -113,7 +210,7 @@ class MLPSpeculatorModel(torch.nn.Module):
for i in range(self.n_predict):
# Project and predict
z = self.emb[i](ind)
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
# z = z.mul(self.emb_weight) # b k d
state = self.proj[i](state) * self.state_weight + z
state = self.activation(self.ln[i](state)) # b k d
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
@ -136,10 +233,11 @@ class MLPSpeculatorModel(torch.nn.Module):
class MLPSpeculatorHead(nn.Module):
def __init__(self, lm_head, mlp_speculator):
def __init__(self, lm_head, mlp_speculator, scale_input: bool):
super().__init__()
self.lm_head = lm_head
self.mlp_speculator = mlp_speculator
self.scale_input = scale_input
def forward(
self, input: torch.Tensor
@ -150,6 +248,8 @@ class MLPSpeculatorHead(nn.Module):
return logits, None
input_ids = logits.argmax(dim=-1)
if self.scale_input:
input = simple_norm(input)
speculative_logits = self.mlp_speculator(input, input_ids)
return logits, speculative_logits
@ -171,6 +271,12 @@ class MLPSpeculatorHead(nn.Module):
)
routing[k] = filename
tie_weights = config.speculator_config.get("tie_weights", False)
if tie_weights:
mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
else:
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
# This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
scale_input = config.speculator_config.get("scale_input", False)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MLPSpeculatorHead(lm_head, mlp_speculator)
return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
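# Illustrative speculator_config fragment (assumed values) that selects the tied,
# input-scaled variant above:
#   {"inner_dim": 4096, "tie_weights": true, "scale_input": true}
# Both flags default to False when absent, so existing speculators keep the untied,
# unscaled behaviour.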

View File

@ -458,6 +458,11 @@ def get_model(
revision=mlp_revision,
filename=filename,
)
speculator_dir_path = Path(mlp_speculator_config).parent
# if these are downloaded, they get converted to safetensors
filenames.extend(
[p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
)
speculator = {
"path": Path(mlp_speculator_config).parent,
"model_paths": filenames,
@ -497,12 +502,11 @@ def get_model(
else -1
)
should_use_sliding_window = (
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
use_sliding_window = sliding_window is not None and sliding_window != -1
needs_sliding_window = (
max_input_tokens is not None and max_input_tokens > sliding_window
)
if should_use_sliding_window:
if max_input_tokens is not None and max_input_tokens > sliding_window:
if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
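# Illustrative example (assumed values): a model with sliding_window = 4096 launched with
# max_input_tokens = 8192 on a backend where SUPPORTS_WINDOWING is False hits the error
# above, while max_input_tokens <= 4096 passes because the window is never exceeded.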
@ -1255,6 +1259,7 @@ def get_model_with_lora_adapters(
"gate_proj",
"up_proj",
"down_proj",
"qkv_proj",
]
for layer_name in adapter_layers:
@ -1282,7 +1287,7 @@ def get_model_with_lora_adapters(
if len(unused_weight_names) > 0:
logger.warning(
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
)
if adapter_tokenizer is not None:

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -296,12 +297,10 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -313,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -388,7 +387,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -402,7 +401,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -454,7 +453,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -477,7 +476,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -518,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -531,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -335,12 +336,10 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -352,7 +351,7 @@ class DbrxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -389,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -403,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -622,7 +621,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Self Attention
@ -635,7 +634,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -679,7 +678,7 @@ class DbrxModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -701,7 +700,7 @@ class DbrxModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -734,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -747,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -29,8 +29,8 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:
@ -363,12 +363,10 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -380,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -666,7 +664,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -680,7 +678,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -729,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -751,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -781,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -794,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -25,11 +25,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
softcap=self.softcap,
)
@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -25,11 +25,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
)
@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -24,11 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
residual = hidden_states
@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,

View File

@ -24,11 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -43,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.import_utils import SYSTEM
def load_attention(config, prefix: str, weights):
@ -167,7 +167,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
@ -192,10 +192,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -207,7 +207,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -268,7 +268,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -281,7 +281,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -328,7 +328,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
@ -351,7 +351,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -382,7 +382,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -395,7 +395,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices=prefill_cache_indices,
)

View File

@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -65,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id):
prefixes = None
if config.model_type == "phi3":
prefix = f"{prefix}.qkv_proj"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=prefix,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
prefixes = ["qkv_proj"]
elif config.model_type == "baichuan":
prefix = f"{prefix}.W_pack"
base_layer = TensorParallelColumnLinear.load_qkv(
@ -84,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id):
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
prefixes = [prefix]
else:
prefixes = ["q_proj", "k_proj", "v_proj"]
sizes = [
@ -194,7 +196,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
@ -218,12 +220,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -235,7 +235,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -375,7 +375,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
@ -390,7 +390,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
@ -479,7 +479,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -504,7 +504,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
@ -548,7 +548,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -562,7 +562,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,

View File

@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
adapter_data,
@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,

View File

@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,

View File

@ -26,11 +26,12 @@ from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.use_parallel_residual:
@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -19,6 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@ -34,6 +35,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
config=config.vision_config,
weights=weights,
)
self.post_vision_tower_layernorm = nn.LayerNorm.load(
prefix="vision_tower.vision_model.post_layernorm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.multi_modal_projector = TensorParallelColumnLinear.load(
config,
@ -65,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -84,7 +90,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_tower(pixel_values)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
last_hidden_state = self.post_vision_tower_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multi_modal_projector(last_hidden_state)
# mask where image or padding tokens
mask = input_ids == self.config.image_token_index
@ -99,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
)
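
The PaliGemma change loads the vision model's `post_layernorm` weights and applies them to the vision tower's `last_hidden_state` before the multi-modal projector, rather than projecting the raw hidden state. A minimal sketch of the new ordering, using plain `nn.LayerNorm`/`nn.Linear` stand-ins and made-up dimensions in place of the sharded TGI layers:

import torch
from torch import nn

hidden_vision, hidden_text, num_patches = 1152, 2048, 256   # hypothetical sizes

post_vision_tower_layernorm = nn.LayerNorm(hidden_vision, eps=1e-6)
multi_modal_projector = nn.Linear(hidden_vision, hidden_text)

# Pretend vision-tower output for a single image.
last_hidden_state = torch.randn(1, num_patches, hidden_vision)

# New ordering: normalize the vision features, then project them into the text
# embedding space before they are scattered over the image-token positions.
image_features = multi_modal_projector(post_vision_tower_layernorm(last_hidden_state))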

View File

@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -24,6 +25,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.import_utils import SYSTEM
class PhiConfig(PretrainedConfig):
@ -159,7 +161,7 @@ class FlashPhiAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Compute query, key, value and split
@ -192,12 +194,10 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -209,7 +209,7 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -276,7 +276,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -289,7 +289,7 @@ class FlashPhiLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -341,7 +341,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@ -363,7 +363,7 @@ class FlashPhiModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -396,7 +396,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -409,7 +409,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:

View File

@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -20,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.import_utils import SYSTEM
def load_attention(config, prefix, weights):
@ -104,7 +106,7 @@ class Qwen2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
@ -135,12 +137,10 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
window_size_left=self.max_past,
)
@ -153,7 +153,7 @@ class Qwen2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -225,7 +225,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
):
@ -240,7 +240,7 @@ class Qwen2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
@ -296,7 +296,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -320,7 +320,7 @@ class Qwen2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
prefill_cache_indices,
)
@ -361,7 +361,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -374,7 +374,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
@ -383,7 +383,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s,
prefill_cache_indices,
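
As in the Mixtral file, the Qwen2 forward only clamps the sequence lengths on the decode path: the prefill flash kernel receives `window_size_left=self.max_past` and needs the true lengths, while paged attention has no window argument and relies on the clamped values instead. A small sketch of that split, assuming a scalar `max_past` taken from the model config:

import torch

max_past = 4096  # hypothetical sliding-window size from the config


def effective_lengths(input_lengths: torch.Tensor, is_prefill: bool) -> torch.Tensor:
    # Prefill: the flash kernel applies the window itself via window_size_left,
    # so it must see the true lengths.
    if is_prefill:
        return input_lengths
    # Decode: paged attention has no window argument, so the lengths are clamped
    # and at most `max_past` cached tokens are attended per request.
    return input_lengths.clamp(max=max_past)


print(effective_lengths(torch.tensor([100, 9000]), is_prefill=False))  # tensor([ 100, 4096])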

View File

@ -5,7 +5,7 @@ import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelColumnLinear,
@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -326,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -343,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -429,7 +426,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
if self.parallel_attn:
@ -443,7 +440,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -465,7 +462,7 @@ class FlashRWLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -552,7 +549,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Layer norm.
@ -567,7 +564,7 @@ class FlashRWLargeLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -628,7 +625,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
@ -650,7 +647,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
@ -680,7 +677,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -693,7 +690,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
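
The Falcon (`FlashRWLargeAttention`) hunk differs from the others in that key and value sit on dim 2 of a grouped layout, so the ipex branch slices `kv[:, :, 0]` / `kv[:, :, 1]` and calls `.contiguous()` on the result, presumably because that kernel expects densely laid-out inputs (an assumption, not stated in the diff). A shape-only sketch of why the slice is non-contiguous, with a made-up [tokens, num_groups, 2, head_dim] packing:

import torch

tokens, num_groups, head_dim = 8, 4, 64

# Hypothetical grouped packing: one key/value pair stacked on dim 2 per KV group.
kv = torch.randn(tokens, num_groups, 2, head_dim)

key = kv[:, :, 0]
print(key.is_contiguous())               # False: slicing an interior dim leaves a strided view
print(key.contiguous().is_contiguous())  # True: .contiguous() copies into dense memory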

Some files were not shown because too many files have changed in this diff.