Merge branch 'main' into gpt_awq_4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Commit 10628e878a by Wang, Yi A, 2024-09-13 04:45:19 -04:00
35 changed files with 4446 additions and 1267 deletions

.github/workflows/nix_tests.yaml (new file, 41 lines)

@ -0,0 +1,41 @@
name: "Nix Tests"
on:
  pull_request:
    paths:
      - ".github/workflows/nix_tests.yaml"
      - "server/**"
      - "proto/**"
      - "router/**"
      - "launcher/**"
      - "Cargo.lock"
      - "rust-toolchain.toml"
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true
jobs:
  tests:
    runs-on:
      group: aws-highmemory-32-plus-priv
    steps:
      - uses: actions/checkout@v4
      - uses: cachix/install-nix-action@v27
        with:
          nix_path: nixpkgs=channel:nixos-unstable
      - uses: cachix/cachix-action@v14
        with:
          name: text-generation-inference
          # If you chose signing key for write access
          authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
        env:
          USER: github_runner
      - name: Build
        run: nix develop .#test --command echo "Ok"
      - name: Pre-commit tests.
        run: nix develop .#test --command pre-commit run --all-files
      - name: Python tests.
        run: nix develop .#test --command python -m pytest server/tests/
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
      - name: Rust tests.
        run: nix develop .#test --command cargo test


@@ -17,19 +17,15 @@ concurrency:
 jobs:
   run_tests:
-    runs-on: ubuntu-latest
+    runs-on:
+      group: aws-highmemory-32-plus-priv
-    env:
-      SCCACHE_GHA_ENABLED: "on"
-      RUSTC_WRAPPER: /usr/local/bin/sccache
-      SCCACHE: 0.3.3
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: actions/setup-python@v1
+        uses: actions/setup-python@v4
+        id: python
         with:
-          python-version: 3.9
+          python-version: 3.11
       - name: Install Rust
         uses: actions-rs/toolchain@v1
         with:
@@ -44,30 +40,9 @@ jobs:
         run: |
           sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
           sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
-      - name: Install sccache
-        run: |
-          curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
-          chmod +x /usr/local/bin/sccache
-      - name: configure sccache
-        uses: actions/github-script@v6
-        with:
-          script: |
-            core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
-            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
-            core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
-            core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
-      - name: cargo registry cache
-        uses: actions/cache@v3
-        with:
-          key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
-          restore-keys: |
-            cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
-            cargo-${{ runner.os }}-
-          path: |
-            ~/.cargo/registry
-            ~/.cargo/git
       - name: Install
         run: |
+          sudo apt install python3.11-dev -y
           make install-cpu
       - name: Run server tests
         run: |
@@ -82,6 +57,3 @@ jobs:
       - name: Run Rust tests
         run: |
           cargo test
-      - name: sccache stats
-        run: |
-          /usr/local/bin/sccache --show-stats

Cargo.lock (generated, 120 changed lines)

@ -2118,6 +2118,15 @@ version = "2.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
 [[package]]
 name = "metrics"
 version = "0.23.0"
@@ -3112,6 +3121,69 @@ dependencies = [
 "prost 0.12.6",
 ]
[[package]]
name = "pyo3"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.76",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.76",
]
 [[package]]
 name = "qoi"
 version = "0.4.1"
@@ -4068,7 +4140,7 @@ dependencies = [
 "pkg-config",
 "text-generation-router",
 "thiserror",
-"tokenizers",
+"tokenizers 0.19.1",
 "tokio",
 "tokio-stream",
 "tracing",
@@ -4091,7 +4163,7 @@ dependencies = [
 "tabled",
 "text-generation-client",
 "thiserror",
-"tokenizers",
+"tokenizers 0.20.0",
 "tokio",
 "tracing",
 "tracing-subscriber",
@@ -4161,6 +4233,7 @@ dependencies = [
 "once_cell",
 "opentelemetry 0.20.0",
 "opentelemetry-otlp",
+"pyo3",
 "rand",
 "regex",
 "reqwest",
@@ -4168,7 +4241,7 @@ dependencies = [
 "serde_json",
 "sysinfo",
 "thiserror",
-"tokenizers",
+"tokenizers 0.20.0",
 "tokio",
 "tokio-stream",
 "tower-http",
@@ -4219,7 +4292,7 @@ dependencies = [
 "slotmap",
 "text-generation-router",
 "thiserror",
-"tokenizers",
+"tokenizers 0.20.0",
 "tokio",
 "tokio-stream",
 "tonic 0.10.2",
@@ -4374,6 +4447,39 @@ dependencies = [
 "unicode_categories",
 ]
[[package]]
name = "tokenizers"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8a24d7f7d6be5b9d1377418b893ab1808af0074f5d1bb2c64784452ddd2aa70"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
 [[package]]
 name = "tokio"
 version = "1.39.3"
@@ -4839,6 +4945,12 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
 [[package]]
 name = "untrusted"
 version = "0.7.1"
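The lock-file additions above pull pyo3 0.22 into text-generation-router and leave the workspace with two tokenizers versions side by side (one crate stays on 0.19.1 while the rest move to 0.20.0). As a point of reference, here is a minimal pyo3 embedding sketch; it is illustrative only (the function below is not part of this PR) and just shows the embedded-interpreter mechanism the new dependency provides.

// Illustrative pyo3 0.22 sketch (assumption: not taken from this PR).
// It embeds a Python interpreter and reads `sys.version`, the same mechanism
// the router gains by depending on `pyo3` above.
use pyo3::prelude::*;

fn embedded_python_version() -> PyResult<String> {
    Python::with_gil(|py| {
        let sys = py.import_bound("sys")?;
        sys.getattr("version")?.extract::<String>()
    })
}

fn main() -> PyResult<()> {
    println!("embedded python: {}", embedded_python_version()?);
    Ok(())
}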


@@ -25,7 +25,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
 [workspace.dependencies]
 base64 = "0.22.0"
-tokenizers = { version = "0.19.1", features = ["http"] }
+tokenizers = { version = "0.20.0", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
 metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
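For reference, the bumped tokenizers crate keeps the "http" feature, so a tokenizer can still be fetched directly from the Hub. A hedged sketch of that usage (the model id is an arbitrary example, not something this PR references):

// Hedged sketch of the `tokenizers` crate with the "http" feature enabled
// (assumption: the model id below is only an example).
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Downloads tokenizer.json from the Hugging Face Hub on first use.
    let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;
    let encoding = tokenizer.encode("What is Deep Learning?", false)?;
    println!("{} tokens: {:?}", encoding.len(), encoding.get_tokens());
    Ok(())
}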


@@ -13,10 +13,13 @@ COPY benchmark benchmark
 COPY router router
 COPY backends backends
 COPY launcher launcher
 RUN cargo chef prepare --recipe-path recipe.json
 FROM chef AS builder
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    python3.11-dev
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
     unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@@ -37,6 +40,7 @@ COPY router router
 COPY backends backends
 COPY launcher launcher
 RUN cargo build --profile release-opt
+RUN cargo build --profile release-opt
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@@ -45,7 +49,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
 ARG PYTORCH_VERSION=2.4.0
-ARG PYTHON_VERSION=3.10
+ARG PYTHON_VERSION=3.11
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.4
 ARG MAMBA_VERSION=24.3.0-0
@@ -216,33 +220,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 COPY --from=pytorch-install /opt/conda /opt/conda
 # Copy build artifacts from flash attention builder
-COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
-COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
+COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from awq kernels builder
-COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from eetq kernels builder
-COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from lorax punica kernels builder
-COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from fbgemm builder
-COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
+COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Copy build artifacts from mamba builder
-COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
+COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
-COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
+COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
-COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/
+COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
@@ -257,7 +261,9 @@ RUN cd server && \
     pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3
-ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
+# Required to find libpython within the rust binaries
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
 # This is needed because exl2 tries to load flash-attn
 # And fails with our builds.
 ENV EXLLAMA_NO_FLASH_ATTN=1


@ -17,6 +17,8 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -64,14 +66,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
hipsolver-dev \ hipsolver-dev \
rccl-dev \ rccl-dev \
cmake \ cmake \
python3-dev && \ python3.11-dev && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1 ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0' ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2' ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10' ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH ENV PATH /opt/conda/bin:$PATH
@ -89,10 +91,18 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \ mamba init && \
rm ~/mambaforge.sh rm ~/mambaforge.sh
# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies # Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir RUN pip install numpy einops ninja --no-cache-dir
RUN conda install intel::mkl-static intel::mkl-include
RUN pip uninstall -y triton && \ RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \ cd triton/python && \
@ -172,19 +182,19 @@ ENV HF_HOME=/data \
PORT=80 PORT=80
# Copy builds artifacts from vllm builder # Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder # Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder # Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder # Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server # Install server
COPY proto proto COPY proto proto
@ -201,6 +211,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher # Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image # AWS Sagemaker compatible image
FROM base AS sagemaker FROM base AS sagemaker


@ -18,6 +18,8 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -42,9 +44,35 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel # Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu
USER root USER root
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
@ -54,7 +82,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
# Text Generation Inference base env # Text Generation Inference base env
ENV HF_HOME=/data \ ENV HF_HOME=/data \
@ -63,9 +91,7 @@ ENV HF_HOME=/data \
WORKDIR /usr/src WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
# Install server # Install server
COPY proto proto COPY proto proto
@ -80,14 +106,12 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router
@ -123,7 +147,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80 PORT=80
ARG MAMBA_VERSION=23.1.0-1 ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10' ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH ENV PATH /opt/conda/bin:$PATH
@ -140,12 +164,19 @@ RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \ bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh rm ~/mambaforge.sh
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
RUN conda install -c conda-forge gperftools mkl RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install triton numa RUN pip install triton py-libnuma
WORKDIR /usr/src WORKDIR /usr/src
@ -156,10 +187,11 @@ RUN cd torch-ccl && git submodule sync && git submodule update --init --recursiv
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# Install server # Install server
COPY proto proto COPY proto proto


@@ -376,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
     // Send generation responses back to the infer task
     // If the receive an error from the Flume channel, it means that the client dropped the
     // request and we need to stop generating hence why we unwrap_or(true)
-    let stopped = send_responses(generation, entry).map_err(|err| {
+    let stopped = send_responses(generation, entry).inspect_err(|_err| {
         tracing::error!("Entry response channel error.");
         metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
-        err
     }).unwrap_or(true);
     if stopped {
         entries.remove(&id).expect("ID not found in entries. This is a bug.");
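The change above replaces map_err, which had to return the error to keep the Result intact, with Result::inspect_err, which only observes it. A standalone sketch of the equivalence, using a dummy send function:

// Standalone sketch of the map_err -> inspect_err change (illustrative; the
// `send` function below is a stand-in, not the router's send_responses).
fn send() -> Result<bool, &'static str> {
    Err("channel closed")
}

fn main() {
    // Before: map_err had to hand the error back so the Result stayed intact.
    let stopped_before = send()
        .map_err(|err| {
            eprintln!("Entry response channel error.");
            err
        })
        .unwrap_or(true);

    // After: inspect_err only observes the error; nothing needs to be returned.
    let stopped_after = send()
        .inspect_err(|_err| {
            eprintln!("Entry response channel error.");
        })
        .unwrap_or(true);

    assert_eq!(stopped_before, stopped_after);
}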


@@ -357,6 +357,7 @@ impl State {
         let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
             (block_allocation, &self.block_allocator)
         {
+            tracing::debug!("Allocating {tokens} with {input_ids:?}");
             match block_allocator.allocate(tokens, input_ids).await {
                 None => {
                     // Entry is over budget


@@ -123,8 +123,6 @@ impl Allocator for RadixAllocator {
             prefill_tokens: prefill_tokens.clone(),
         };
-        tracing::debug!("Blocks {blocks:?}");
         self.allocation_id += 1;
         self.allocations.insert(self.allocation_id, allocation);


@ -492,6 +492,24 @@
"type": "github" "type": "github"
} }
}, },
"flake-utils_7": {
"inputs": {
"systems": "systems_7"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"gitignore": { "gitignore": {
"inputs": { "inputs": {
"nixpkgs": [ "nixpkgs": [
@ -700,16 +718,16 @@
}, },
"nixpkgs_6": { "nixpkgs_6": {
"locked": { "locked": {
"lastModified": 1723912943, "lastModified": 1724915739,
"narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=", "narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
"owner": "danieldk", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "b82cdca86dbb30013b76c4b55d48806476820a5c", "rev": "85be051bb60943d3328d91aaf2598798f87e19af",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "danieldk", "owner": "nixos",
"ref": "cuda-12.4", "ref": "nixos-unstable-small",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }
@ -835,11 +853,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1724638882, "lastModified": 1726021481,
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=", "narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694", "rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -938,17 +956,33 @@
"type": "github" "type": "github"
} }
}, },
"systems_7": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": { "tgi-nix": {
"inputs": { "inputs": {
"flake-compat": "flake-compat_4", "flake-compat": "flake-compat_4",
"flake-utils": "flake-utils_7",
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1725011596, "lastModified": 1725950569,
"narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=", "narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
"owner": "danieldk", "owner": "danieldk",
"repo": "tgi-nix", "repo": "tgi-nix",
"rev": "717c2b07e38538abf05237cca65b2d1363c2c9af", "rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
"type": "github" "type": "github"
}, },
"original": { "original": {


@@ -46,12 +46,30 @@
       launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
         inherit crateOverrides;
       };
-      router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
-        inherit crateOverrides;
-      };
+      router =
+        let
+          routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
+            inherit crateOverrides;
+          };
+          packagePath =
+            with pkgs.python3.pkgs;
+            makePythonPath [
+              protobuf
+              sentencepiece
+              torch
+              transformers
+            ];
+        in
+        pkgs.writeShellApplication {
+          name = "text-generation-router";
+          text = ''
+            PYTHONPATH="${packagePath}" ${routerUnwrapped}/bin/text-generation-router "$@"
+          '';
+        };
       server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
     in
     {
+      formatter = pkgs.nixfmt-rfc-style;
       devShells = with pkgs; rec {
         default = pure;
@@ -63,6 +81,29 @@
             server
           ];
         };
test = mkShell {
buildInputs =
[
# benchmark
# launcher
# router
server
openssl.dev
pkg-config
cargo
rustfmt
clippy
]
++ (with python3.pkgs; [
docker
pytest
pytest-asyncio
syrupy
pre-commit
ruff
]);
};
       impure = mkShell {
         buildInputs =
@@ -82,6 +123,7 @@
             docker
             pip
             ipdb
+            click
             pyright
             pytest
             pytest-asyncio
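The writeShellApplication wrapper earlier in this file exports PYTHONPATH before invoking the router, presumably so the interpreter embedded through pyo3 can import protobuf, sentencepiece, torch and transformers. An illustrative sketch (not code from this PR) of how those PYTHONPATH entries become visible inside an embedded interpreter:

// Illustrative pyo3 sketch (assumption: not taken from this PR). Entries
// supplied via PYTHONPATH show up in the embedded interpreter's sys.path,
// which is what lets a wrapped binary import packages such as transformers.
use pyo3::prelude::*;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        let sys = py.import_bound("sys")?;
        let path: Vec<String> = sys.getattr("path")?.extract()?;
        println!("sys.path seen by the embedded interpreter: {path:#?}");
        Ok(())
    })
}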


@@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
 from text_generation import AsyncClient
 from text_generation.types import (
     BestOfSequence,
+    Message,
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
@@ -97,25 +98,25 @@ class ResponseComparator(JSONSnapshotExtension):
     ) -> bool:
         def convert_data(data):
             data = json.loads(data)
-            if isinstance(data, Dict) and "choices" in data:
-                choices = data["choices"]
-                if isinstance(choices, List) and len(choices) >= 1:
-                    if "delta" in choices[0]:
-                        return ChatCompletionChunk(**data)
-                    if "text" in choices[0]:
-                        return Completion(**data)
-                return ChatComplete(**data)
-            if isinstance(data, Dict):
-                return Response(**data)
-            if isinstance(data, List):
-                if (
-                    len(data) > 0
-                    and "object" in data[0]
-                    and data[0]["object"] == "text_completion"
-                ):
-                    return [Completion(**d) for d in data]
-                return [Response(**d) for d in data]
-            raise NotImplementedError
+            return _convert_data(data)
+
+        def _convert_data(data):
+            if isinstance(data, Dict):
+                if "choices" in data:
+                    data["choices"] = list(
+                        sorted(data["choices"], key=lambda x: x["index"])
+                    )
+                    choices = data["choices"]
+                    if isinstance(choices, List) and len(choices) >= 1:
+                        if "delta" in choices[0]:
+                            return ChatCompletionChunk(**data)
+                        if "text" in choices[0]:
+                            return Completion(**data)
+                    return ChatComplete(**data)
+                else:
+                    return Response(**data)
+            if isinstance(data, List):
+                return [_convert_data(d) for d in data]
+            raise NotImplementedError

         def eq_token(token: Token, other: Token) -> bool:
@ -571,3 +572,38 @@ def generate_load():
         return await asyncio.gather(*futures)
     return generate_load_inner
@pytest.fixture(scope="module")
def generate_multi():
    async def generate_load_inner(
        client: AsyncClient,
        prompts: List[str],
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:
        import numpy as np

        arange = np.arange(len(prompts))
        perm = np.random.permutation(arange)
        rperm = [-1] * len(perm)
        for i, p in enumerate(perm):
            rperm[p] = i
        shuffled_prompts = [prompts[p] for p in perm]
        futures = [
            client.chat(
                messages=[Message(role="user", content=prompt)],
                max_tokens=max_new_tokens,
                temperature=0,
                seed=seed,
            )
            for prompt in shuffled_prompts
        ]
        shuffled_responses = await asyncio.gather(*futures)
        responses = [shuffled_responses[p] for p in rperm]
        return responses
    return generate_load_inner
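The generate_multi fixture shuffles the prompts, issues the chat requests concurrently, and then applies the inverse permutation so the responses come back in the original prompt order. A standalone sketch of that index bookkeeping, with a fixed permutation standing in for np.random.permutation:

// Standalone sketch of the shuffle / inverse-permutation bookkeeping used by
// the fixture above (illustrative; a fixed permutation replaces np.random.permutation).
fn main() {
    let prompts = ["p0", "p1", "p2", "p3"];
    let perm = [2usize, 0, 3, 1];

    // rperm is the inverse permutation: rperm[perm[i]] = i.
    let mut rperm = vec![0usize; perm.len()];
    for (i, &p) in perm.iter().enumerate() {
        rperm[p] = i;
    }

    // "Send" the prompts in shuffled order; pretend each response echoes its prompt.
    let shuffled_prompts: Vec<&str> = perm.iter().map(|&p| prompts[p]).collect();
    let shuffled_responses = shuffled_prompts.clone();

    // Undo the shuffle so responses line up with the original prompt order.
    let responses: Vec<&str> = rperm.iter().map(|&i| shuffled_responses[i]).collect();
    assert_eq!(responses, prompts);
    println!("{responses:?}");
}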


@ -1,38 +1,38 @@
{ {
"choices": [ "choices": [
{ {
"finish_reason": "stop", "finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " A Beginners Guide\nDeep learning is a subset"
},
{
"finish_reason": "length",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": " PR for more information?" "text": " This is a question that has puzzled many people for"
}, },
{ {
"finish_reason": "length", "finish_reason": "length",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": "hd20220811-" "text": "usculas_minusculas(s):\n \"\"\"\n"
},
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
}, },
{ {
"finish_reason": "length", "finish_reason": "length",
"index": 2, "index": 2,
"logprobs": null, "logprobs": null,
"text": " severely flawed and often has a substandard" "text": " Paris\nWhat is the capital of France?\nThe"
} }
], ],
"created": 1722014725, "created": 1725877154,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native", "system_fingerprint": "2.2.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 36, "completion_tokens": 40,
"prompt_tokens": 8, "prompt_tokens": 22,
"total_tokens": 44 "total_tokens": 62
} }
} }


@ -5,12 +5,12 @@
"finish_reason": "", "finish_reason": "",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": "\n" "text": " A"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -20,12 +20,72 @@
"finish_reason": "", "finish_reason": "",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": "\n" "text": " This"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " Paris"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "us"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Beginner"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -38,9 +98,9 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -50,12 +110,12 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": "hd" "text": "cul"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -65,12 +125,12 @@
"finish_reason": "", "finish_reason": "",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": "\n" "text": "s"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -80,12 +140,12 @@
"finish_reason": "", "finish_reason": "",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": "\n" "text": " a"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -95,12 +155,12 @@
"finish_reason": "", "finish_reason": "",
"index": 2, "index": 2,
"logprobs": null, "logprobs": null,
"text": "\n" "text": "What"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -110,12 +170,12 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": "aho" "text": "as"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -125,12 +185,12 @@
"finish_reason": "", "finish_reason": "",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": "2" "text": " Guide"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -140,252 +200,12 @@
"finish_reason": "", "finish_reason": "",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": "2" "text": " question"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -398,9 +218,9 @@
"text": " is" "text": " is"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -410,12 +230,12 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": "m" "text": "_minus"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -425,12 +245,12 @@
"finish_reason": "", "finish_reason": "",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": " Room" "text": "\n"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -440,12 +260,12 @@
"finish_reason": "", "finish_reason": "",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": "s" "text": " that"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -458,9 +278,9 @@
"text": " the" "text": " the"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -470,12 +290,12 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": " tired" "text": "cul"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -485,12 +305,12 @@
"finish_reason": "", "finish_reason": "",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": ":" "text": "Deep"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -500,12 +320,12 @@
"finish_reason": "", "finish_reason": "",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": "'" "text": " has"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -518,9 +338,9 @@
"text": " capital" "text": " capital"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -530,12 +350,192 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": "," "text": "as"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " learning"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " puzzled"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " of"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "(s"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " many"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " France"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "):\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " people"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "?\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " "
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -545,12 +545,12 @@
"finish_reason": "length", "finish_reason": "length",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": " She" "text": " subset"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -560,12 +560,12 @@
"finish_reason": "length", "finish_reason": "length",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": " scale" "text": " for"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -575,12 +575,12 @@
"finish_reason": "length", "finish_reason": "length",
"index": 2, "index": 2,
"logprobs": null, "logprobs": null,
"text": " of" "text": "The"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
@ -590,12 +590,12 @@
"finish_reason": "length", "finish_reason": "length",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": " its" "text": " \"\"\"\n"
} }
], ],
"created": 1724833943, "created": 1725883643,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native" "system_fingerprint": "2.2.1-dev0-native"
} }

View File

@ -4,17 +4,17 @@
"finish_reason": "length", "finish_reason": "length",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": " PR for flake8" "text": " A Beginners Guide\nDeep learning is a subset"
} }
], ],
"created": 1713284454, "created": 1725876621,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native", "system_fingerprint": "2.2.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 5, "completion_tokens": 10,
"prompt_tokens": 6, "prompt_tokens": 6,
"total_tokens": 11 "total_tokens": 16
} }
} }

View File

@ -11,7 +11,7 @@ from text_generation.types import (
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher): def flash_llama_completion_handle(launcher):
with launcher( with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", "meta-llama/Meta-Llama-3.1-8B-Instruct",
) as handle: ) as handle:
yield handle yield handle
@ -34,16 +34,19 @@ def test_flash_llama_completion_single_prompt(
f"{flash_llama_completion.base_url}/v1/completions", f"{flash_llama_completion.base_url}/v1/completions",
json={ json={
"model": "tgi", "model": "tgi",
"prompt": "Say this is a test", "prompt": "What is Deep Learning?",
"max_tokens": 5, "max_tokens": 10,
"seed": 0, "temperature": 0.0,
}, },
headers=flash_llama_completion.headers, headers=flash_llama_completion.headers,
stream=False, stream=False,
) )
response = response.json() response = response.json()
assert len(response["choices"]) == 1 assert len(response["choices"]) == 1
assert (
response["choices"][0]["text"]
== " A Beginners Guide\nDeep learning is a subset"
)
assert response == response_snapshot assert response == response_snapshot
@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
f"{flash_llama_completion.base_url}/v1/completions", f"{flash_llama_completion.base_url}/v1/completions",
json={ json={
"model": "tgi", "model": "tgi",
"prompt": ["Say", "this", "is", "a"], "prompt": [
"What is Deep Learning?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10, "max_tokens": 10,
"seed": 0, "seed": 0,
"temperature": 0.0,
}, },
headers=flash_llama_completion.headers, headers=flash_llama_completion.headers,
stream=False, stream=False,
@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
response = response.json() response = response.json()
assert len(response["choices"]) == 4 assert len(response["choices"]) == 4
all_indexes = [choice["index"] for choice in response["choices"]] all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
all_indexes.sort() all_indexes.sort()
assert all_indexes == [0, 1, 2, 3] all_indices, all_strings = zip(*all_indexes)
assert list(all_indices) == [0, 1, 2, 3]
assert list(all_strings) == [
" A Beginners Guide\nDeep learning is a subset",
" This is a question that has puzzled many people for",
" Paris\nWhat is the capital of France?\nThe",
'usculas_minusculas(s):\n """\n',
]
assert response == response_snapshot assert response == response_snapshot
@ -77,19 +93,21 @@ async def test_flash_llama_completion_many_prompts_stream(
request = { request = {
"model": "tgi", "model": "tgi",
"prompt": [ "prompt": [
"What color is the sky?", "What is Deep Learning?",
"Is water wet?", "Is water wet?",
"What is the capital of France?", "What is the capital of France?",
"def mai", "def mai",
], ],
"max_tokens": 10, "max_tokens": 10,
"seed": 0, "seed": 0,
"temperature": 0.0,
"stream": True, "stream": True,
} }
url = f"{flash_llama_completion.base_url}/v1/completions" url = f"{flash_llama_completion.base_url}/v1/completions"
chunks = [] chunks = []
strings = [""] * 4
async with ClientSession(headers=flash_llama_completion.headers) as session: async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response: async with session.post(url, json=request) as response:
# iterate over the stream # iterate over the stream
@ -108,7 +126,15 @@ async def test_flash_llama_completion_many_prompts_stream(
for c in chunk: for c in chunk:
chunks.append(Completion(**c)) chunks.append(Completion(**c))
assert "choices" in c assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4 index = c["choices"][0]["index"]
assert 0 <= index <= 4
strings[index] += c["choices"][0]["text"]
assert response.status == 200 assert response.status == 200
assert list(strings) == [
" A Beginners Guide\nDeep learning is a subset",
" This is a question that has puzzled many people for",
" Paris\nWhat is the capital of France?\nThe",
'usculas_minusculas(s):\n """\n',
]
assert chunks == response_snapshot assert chunks == response_snapshot
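The streaming test above rebuilds each prompt's completion by concatenating chunk texts per choice index. For readers who want to replicate that outside pytest, here is a minimal standalone sketch; it assumes a TGI instance reachable at a placeholder URL (http://localhost:8080), the aiohttp dependency already used by the test suite, and OpenAI-style `data: {...}` server-sent events on `/v1/completions` ending with `data: [DONE]`.

```python
# Sketch only: regroup streamed /v1/completions chunks by choice index.
# Assumptions: a local TGI server at a placeholder URL, SSE lines of the
# form `data: {...}` terminated by `data: [DONE]`.
import asyncio
import json

from aiohttp import ClientSession


async def collect_streamed_completions(base_url: str, prompts: list[str]) -> list[str]:
    request = {
        "model": "tgi",
        "prompt": prompts,
        "max_tokens": 10,
        "seed": 0,
        "temperature": 0.0,
        "stream": True,
    }
    texts = [""] * len(prompts)
    async with ClientSession() as session:
        async with session.post(f"{base_url}/v1/completions", json=request) as response:
            async for raw_line in response.content:  # yields one SSE line at a time
                line = raw_line.decode("utf-8").strip()
                if not line.startswith("data:") or line.endswith("[DONE]"):
                    continue
                chunk = json.loads(line[len("data:") :])
                choice = chunk["choices"][0]
                texts[choice["index"]] += choice["text"]
    return texts


if __name__ == "__main__":
    out = asyncio.run(
        collect_streamed_completions("http://localhost:8080", ["What is Deep Learning?"])
    )
    print(out)
```

This mirrors the per-index accumulation the test asserts on, without the snapshot machinery.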

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@ -6,9 +6,10 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
pydantic = "> 2, < 3" pydantic = "> 2, < 3"
python = ">=3.9,<3.13" python = ">=3.10,<3.13"
syrupy = "^4.7.1" syrupy = "^4.7.1"
text-generation = "^0.6.0" text-generation = "^0.6.0"
pytest = "^7.4.0" pytest = "^7.4.0"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
docker = "^6.1.3" docker = "^7"
numpy = "^1.20"

View File

@ -1,34 +1,35 @@
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13" aiohappyeyeballs==2.4.0 ; python_version >= "3.10" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" aiohttp==3.10.5 ; python_version >= "3.10" and python_version < "3.13"
annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13" aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11" docker==7.1.0 ; python_version >= "3.10" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13" filelock==3.16.0 ; python_version >= "3.10" and python_version < "3.13"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13" frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.9.0 ; python_version >= "3.10" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13" idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13" iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" multidict==6.1.0 ; python_version >= "3.10" and python_version < "3.13"
pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13"
pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13" packaging==24.1 ; python_version >= "3.10" and python_version < "3.13"
pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13" pluggy==1.5.0 ; python_version >= "3.10" and python_version < "3.13"
pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13" pydantic-core==2.23.3 ; python_version >= "3.10" and python_version < "3.13"
pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13" pydantic==2.9.1 ; python_version >= "3.10" and python_version < "3.13"
pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" pytest-asyncio==0.21.2 ; python_version >= "3.10" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pytest==7.4.4 ; python_version >= "3.10" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" pywin32==306 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13"
text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11" syrupy==4.7.1 ; python_version >= "3.10" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" text-generation==0.6.1 ; python_version >= "3.10" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.5 ; python_version >= "3.10" and python_version < "3.13"
websocket-client==1.6.2 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
yarl==1.11.1 ; python_version >= "3.10" and python_version < "3.13"

View File

@ -1843,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,
) )
.map_err(|err| { .inspect_err(|_| {
shutdown_shards(shutdown.clone(), &shutdown_receiver); shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?; })?;
// Default exit code // Default exit code

View File

@ -28,6 +28,9 @@ defaultCrateOverrides
]; ];
}; };
}; };
pyo3-build-config = attrs: {
buildInputs = [ python3 ];
};
text-generation-benchmark = attrs: { text-generation-benchmark = attrs: {
src = filter { src = filter {
root = ../benchmark; root = ../benchmark;

View File

@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] } ] }
csv = "1.3.0" csv = "1.3.0"
ureq = "=2.9" ureq = "=2.9"
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[build-dependencies] [build-dependencies]

View File

@ -336,6 +336,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError), ValidationError(#[from] ValidationError),
#[error("Incomplete generation")] #[error("Incomplete generation")]
IncompleteGeneration, IncompleteGeneration,
#[error("Incomplete generation stream")]
IncompleteGenerationStream,
#[error("Template error: {0}")] #[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error), TemplateError(#[from] minijinja::Error),
#[error("Missing template vatiable: {0}")] #[error("Missing template vatiable: {0}")]
@ -351,6 +353,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded", InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation", InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation", InferError::IncompleteGeneration => "incomplete_generation",
InferError::IncompleteGenerationStream => "incomplete_generation_stream",
InferError::TemplateError(_) => "template_error", InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error", InferError::ToolError(_) => "tool_error",

View File

@ -41,6 +41,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType}; use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION; use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::fs::File; use std::fs::File;
@ -48,7 +49,6 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use thiserror::Error; use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select; use tokio::select;
use tokio::signal; use tokio::signal;
@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
metrics::counter!("tgi_request_count").increment(1); metrics::counter!("tgi_request_count").increment(1);
// Do not long ultra long inputs, like image payloads. // Do not long ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]); tracing::debug!(
"Input: {}",
&req.inputs.chars().take(1000).collect::<String>()
);
let compute_characters = req.inputs.chars().count(); let compute_characters = req.inputs.chars().count();
let mut add_prompt = None; let mut add_prompt = None;
@ -674,7 +677,7 @@ async fn generate_stream_internal(
// Check if generation reached the end // Check if generation reached the end
// Skip if we already sent an error // Skip if we already sent an error
if !end_reached && !error { if !end_reached && !error {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGenerationStream;
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
@ -1857,18 +1860,34 @@ pub async fn run(
}); });
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| { let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok(); use pyo3::prelude::*;
if let Some(tokenizer) = &mut tokenizer { let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
if let Some(class) = &tokenizer_config.tokenizer_class { let transformers = py.import_bound("transformers")?;
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ let auto = transformers.getattr("AutoTokenizer")?;
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { let from_pretrained = auto.getattr("from_pretrained")?;
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); let args = (tokenizer_name.to_string(),);
tokenizer.with_post_processor(post_processor); let kwargs = [(
} "revision",
} revision.clone().unwrap_or_else(|| "main".to_string()),
} )]
} .into_py_dict_bound(py);
tokenizer let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
});
let filename = if convert.is_ok() {
// If we have correctly loaded and resaved with transformers
// We might have modified the tokenizer.json according to transformers
"out/tokenizer.json".into()
} else {
filename
};
Tokenizer::from_file(filename).ok()
}); });
let config: Option<Config> = config_filename.and_then(|filename| { let config: Option<Config> = config_filename.and_then(|filename| {
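The hunk above makes the router call into Python (via pyo3) to load and re-save the tokenizer with transformers before handing the resulting tokenizer.json to the Rust `tokenizers` crate. Expressed directly in Python, the embedded interpreter is asked to do roughly the following (a sketch; the model id and revision shown are simply the ones appearing elsewhere in this diff):

```python
# Python equivalent of the pyo3 calls made by the router (sketch): load the
# tokenizer with transformers, re-save it, and let the Rust side read the
# re-serialized tokenizer.json from the "out" directory.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # tokenizer_name on the Rust side
    revision="main",                          # defaults to "main" when no revision is set
)
tokenizer.save_pretrained("out")
# On success the router loads "out/tokenizer.json"; on failure it falls back
# to the originally downloaded tokenizer file.
```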
@ -2555,6 +2574,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
@ -2587,77 +2607,6 @@ pub enum WebServerError {
Axum(#[from] axum::BoxError), Axum(#[from] axum::BoxError),
} }
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
type PreparedInput = (String, Option<GrammarType>, bool); type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input( fn prepare_chat_input(

View File

@ -1,2 +1,2 @@
install-flashinfer: install-flashinfer:
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4 pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4

View File

@ -267,7 +267,7 @@ def test_batch_concatenate(
assert next_batch.max_input_length == 3 assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -262,7 +262,7 @@ def test_batch_concatenate(
assert next_batch.max_input_length == 3 assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -281,7 +281,7 @@ def test_batch_concatenate(
assert next_batch.max_decoder_input_length == 3 assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0] assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -22,9 +22,9 @@ def attention(
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention( ipex.llm.functional.varlen_attention(
q, q.contiguous() if q.device.type == "xpu" else q,
key_cache, key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache, value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
out, out,
seqlen.cu_seqlen_q, seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q, seqlen.cu_seqlen_q,

View File

@ -82,7 +82,7 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
import numa import numa
import psutil import psutil
nodes = numa.get_max_node() + 1 nodes = numa.info.get_max_node() + 1
rank_per_node = math.ceil(world_size / nodes) rank_per_node = math.ceil(world_size / nodes)
num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
node_id = int(rank_id / rank_per_node) node_id = int(rank_id / rank_per_node)
@ -91,18 +91,22 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
else: else:
num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
if len(numa.get_membind()) == nodes: if len(numa.memory.get_membind_nodes()) == nodes:
numa.set_membind([node_id]) numa.memory.set_membind_nodes((node_id))
torch.set_num_threads(num_cpus_per_rank) torch.set_num_threads(num_cpus_per_rank)
if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
cpu_start = num_cpus_per_rank * rank_offset_per_node cpu_start = num_cpus_per_rank * rank_offset_per_node
numa.set_affinity( numa.schedule.run_on_cpus(
0, 0,
list(numa.node_to_cpus(node_id))[ *(
cpu_start : cpu_start + num_cpus_per_rank numa.info.node_to_cpus(node_id)[
], cpu_start : cpu_start + num_cpus_per_rank
]
),
) )
logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") logger.info(
f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
)
@dataclass @dataclass
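The NUMA hunk above migrates from the old flat `numa.*` helpers to the namespaced `numa.info` / `numa.memory` / `numa.schedule` API of py-libnuma. As a compact illustration of the same binding logic, here is a sketch assuming py-libnuma >= 2 and psutil, and omitting the OMP_NUM_THREADS override and "already bound" checks that the real `init_cpu_threads_env` performs:

```python
# Sketch of the per-rank CPU/memory binding, using the namespaced py-libnuma API.
import math

import numa
import psutil
import torch


def bind_rank(rank_id: int, world_size: int) -> None:
    nodes = numa.info.get_max_node() + 1
    ranks_per_node = math.ceil(world_size / nodes)
    node_id = rank_id // ranks_per_node
    rank_offset = rank_id % ranks_per_node

    physical_per_node = int(psutil.cpu_count(logical=False) / nodes)
    cpus_per_rank = max(physical_per_node // ranks_per_node, 1)

    # Bind memory allocations of the current process to this rank's node.
    numa.memory.set_membind_nodes(node_id)
    torch.set_num_threads(cpus_per_rank)

    # Pin execution (pid 0 == current process) to a contiguous slice of the node's cores.
    cpu_start = cpus_per_rank * rank_offset
    numa.schedule.run_on_cpus(
        0,
        *numa.info.node_to_cpus(node_id)[cpu_start : cpu_start + cpus_per_rank],
    )
```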
@ -272,6 +276,8 @@ class FlashCausalLMBatch(Batch):
assert prefix_len > 0 assert prefix_len > 0
prefix_len -= 1 prefix_len -= 1
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_len]) prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:] tokenized_input = tokenized_input[prefix_len:]
@ -515,6 +521,7 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
assert len(pb.requests) > 0
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@ -640,6 +647,7 @@ class FlashCausalLMBatch(Batch):
adapter_segments = torch.tensor( adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device adapter_segments, dtype=torch.int32, device=device
) )
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
return type(self)( return type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
@ -834,6 +842,8 @@ class FlashCausalLMBatch(Batch):
start_slots = torch.concat(start_slots) start_slots = torch.concat(start_slots)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype, dtype=batches[0].next_token_chooser.dtype,
@ -1152,27 +1162,6 @@ class FlashCausalLM(Model):
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lens=prefix_lengths, prefix_lens=prefix_lengths,
) )
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
}
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import ( from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs, create_decode_state_cuda_graphs,
) )
@ -1189,21 +1178,38 @@ class FlashCausalLM(Model):
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) )
self.cuda_graphs[bs]["state"] = state
else: else:
state = None state = None
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
"state": state,
"graph": graph,
}
torch.cuda.synchronize() torch.cuda.synchronize()
# Run once outside to warmup # Run once outside to warmup
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor, input_lengths_tensor=input_lengths_tensor,
state=state, state=state,
prefix_lens=prefix_lengths,
prefix_lens_tensor=prefix_lengths_tensor, prefix_lens_tensor=prefix_lengths_tensor,
): ):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self.model.forward( self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1216,6 +1222,7 @@ class FlashCausalLM(Model):
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
) )
del seqlen
torch.cuda.synchronize() torch.cuda.synchronize()
@ -1481,9 +1488,7 @@ class FlashCausalLM(Model):
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths, input_lengths_tensor=input_lengths,
input_lengths_tensor=input_lengths + prefix_lens_tensor,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor, prefix_lens_tensor=prefix_lens_tensor,
): ):
max_k = (input_lengths + prefix_lens_tensor).max().item() max_k = (input_lengths + prefix_lens_tensor).max().item()
@ -1521,26 +1526,28 @@ class FlashCausalLM(Model):
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, prefix_lens=batch.prefix_lens,
) )
# assert block_tables.shape[0] >= slots.shape[0]
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else: else:
cuda_graph["block_tables"][ cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1] : block_tables.shape[0], : block_tables.shape[1]
] = block_tables ] = block_tables
cuda_graph["slots"].fill_(-1)
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
input_lengths + prefix_lens_tensor cuda_graph["prefix_lengths"].zero_()
) cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
with self._forward_context( with self._forward_context(
block_tables=cuda_graph["block_tables"], block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths=batch.input_lengths,
input_lengths_tensor=cuda_graph["input_lengths"], input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens=batch.prefix_lens, prefix_lens_tensor=cuda_graph["prefix_lengths"],
prefix_lens_tensor=prefix_lens_tensor, state=cuda_graph["state"],
state=cuda_graph.get("state"),
): ):
# Replay the graph # Replay the graph
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
@ -1769,7 +1776,7 @@ class FlashCausalLM(Model):
left = 0 left = 0
if n_accepted_ids > 1: if n_accepted_ids > 1:
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False current_stopped = False
for j in range(index, index + n_accepted_ids): for j in range(index, index + n_accepted_ids):
@ -1924,9 +1931,7 @@ class FlashCausalLM(Model):
*, *,
block_tables: torch.Tensor, block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths: List[int],
input_lengths_tensor: torch.Tensor, input_lengths_tensor: torch.Tensor,
prefix_lens: List[int],
prefix_lens_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> ContextManager: ) -> ContextManager:
@ -1952,7 +1957,7 @@ class FlashCausalLM(Model):
# ), # ),
block_tables=block_tables, block_tables=block_tables,
cu_seqlens=cu_seqlen_prefill, cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor, input_lengths=input_lengths_tensor + prefix_lens_tensor,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
@ -1962,7 +1967,7 @@ class FlashCausalLM(Model):
assert input_lengths_tensor is not None assert input_lengths_tensor is not None
return use_decode_state( return use_decode_state(
state=state if state is not None else self.decode_state, state=state if state is not None else self.decode_state,
input_lengths=input_lengths_tensor, input_lengths=input_lengths_tensor + prefix_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,

View File

@ -367,9 +367,7 @@ class VlmCausalLM(FlashCausalLM):
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor, prefix_lens_tensor=prefix_lens_tensor,
): ):
max_k = (input_lengths + prefix_lens_tensor).max().item() max_k = (input_lengths + prefix_lens_tensor).max().item()

View File

@ -77,12 +77,12 @@ def load_and_merge_adapters(
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_info) == 1: if len(adapter_parameters.adapter_info) == 1:
adapter_info = next(iter(adapter_parameters.adapter_info)) adapter = next(iter(adapter_parameters.adapter_info))
return load_module_map( return load_module_map(
model_id, model_id,
adapter_info.revision, adapter.revision,
adapter_info.id, adapter.id,
adapter_info.path, adapter.path,
weight_names, weight_names,
trust_remote_code, trust_remote_code,
) )
@ -90,7 +90,6 @@ def load_and_merge_adapters(
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index) adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
return _load_and_merge( return _load_and_merge(
model_id, model_id,
adapter_params.revision,
adapter_params, adapter_params,
weight_names, weight_names,
trust_remote_code, trust_remote_code,
@ -109,7 +108,6 @@ class AdapterParametersContainer:
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _load_and_merge( def _load_and_merge(
model_id: str, model_id: str,
revision: str,
adapter_params: AdapterParametersContainer, adapter_params: AdapterParametersContainer,
weight_names: Tuple[str], weight_names: Tuple[str],
trust_remote_code: bool = False, trust_remote_code: bool = False,
@ -126,6 +124,7 @@ def _load_and_merge(
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map( load_module_map(
model_id, model_id,
adapter.revision,
adapter.id, adapter.id,
adapter.path, adapter.path,
weight_names, weight_names,