Merge branch 'main' into hot_fix_xpu

This commit is contained in:
Wang, Yi A 2024-08-06 21:57:25 -07:00
commit 30e70f2ceb
109 changed files with 5433 additions and 1838 deletions

View File

View File

View File

@ -2,3 +2,5 @@ aml
target target
server/transformers server/transformers
server/flash-attention server/flash-attention
cmake-build-debug/
cmake-build-release/

View File

@ -28,7 +28,7 @@ jobs:
- name: Install router - name: Install router
id: install-router id: install-router
run: cargo install --path router/ run: cargo install --path backends/v3/
- uses: actions/setup-node@v4 - uses: actions/setup-node@v4
with: with:
@ -41,5 +41,5 @@ jobs:
- name: Check that documentation is up-to-date - name: Check that documentation is up-to-date
run: | run: |
npm install -g swagger-cli npm install -g @redocly/cli
python update_doc.py --check python update_doc.py --check

View File

@ -10,6 +10,7 @@ on:
paths: paths:
- ".github/workflows/build.yaml" - ".github/workflows/build.yaml"
- "integration-tests/**" - "integration-tests/**"
- "backends/**"
- "server/**" - "server/**"
- "proto/**" - "proto/**"
- "router/**" - "router/**"

4
.gitignore vendored
View File

@ -3,6 +3,10 @@ target
router/tokenizer.json router/tokenizer.json
*__pycache__* *__pycache__*
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files # ROCm auto-generated files
*.hip *.hip
server/exllamav2_kernels/exllamav2_kernels/hip/ server/exllamav2_kernels/exllamav2_kernels/hip/

View File

@ -13,8 +13,8 @@ repos:
- repo: https://github.com/doublify/pre-commit-rust - repo: https://github.com/doublify/pre-commit-rust
rev: v1.0 rev: v1.0
hooks: hooks:
- id: fmt
- id: cargo-check - id: cargo-check
- id: fmt
- id: clippy - id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0 rev: v0.3.0

79
.redocly.lint-ignore.yaml Normal file
View File

@ -0,0 +1,79 @@
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
# See https://redoc.ly/docs/cli/ for more information.
docs/openapi.json:
no-empty-servers:
- '#/openapi'
spec:
- >-
#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
- >-
#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
- >-
#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
- '#/components/schemas/GenerateResponse/properties/details/nullable'
- '#/components/schemas/StreamResponse/properties/details/nullable'
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
- '#/components/schemas/ToolChoice/nullable'
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
no-invalid-media-type-examples:
- '#/paths/~1/post/responses/422/content/application~1json/example'
- '#/paths/~1/post/responses/424/content/application~1json/example'
- '#/paths/~1/post/responses/429/content/application~1json/example'
- '#/paths/~1/post/responses/500/content/application~1json/example'
- '#/paths/~1generate/post/responses/422/content/application~1json/example'
- '#/paths/~1generate/post/responses/424/content/application~1json/example'
- '#/paths/~1generate/post/responses/429/content/application~1json/example'
- '#/paths/~1generate/post/responses/500/content/application~1json/example'
- >-
#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
- >-
#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
operation-4xx-response:
- '#/paths/~1health/get/responses'
- '#/paths/~1info/get/responses'
- '#/paths/~1metrics/get/responses'
no-unused-components:
- '#/components/schemas/Completion'
security-defined:
- '#/paths/~1/post'
- '#/paths/~1generate/post'
- '#/paths/~1generate_stream/post'
- '#/paths/~1health/get'
- '#/paths/~1info/get'
- '#/paths/~1metrics/get'
- '#/paths/~1tokenize/post'
- '#/paths/~1v1~1chat~1completions/post'
- '#/paths/~1v1~1completions/post'

734
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,19 @@
[workspace] [workspace]
members = [ members = [
"benchmark", "benchmark",
"router", "backends/v3",
"router/client", "backends/grpc-metadata",
"router/grpc-metadata", "backends/trtllm",
"launcher" "backends/client",
"launcher"
]
default-members = [
"benchmark",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"backends/client",
"launcher"
] ]
resolver = "2" resolver = "2"
@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
base64 = "0.22.0" base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] } tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] } hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
[profile.release] [profile.release]
incremental = true incremental = true

View File

@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt

23
Dockerfile.trtllm Normal file
View File

@ -0,0 +1,23 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
# Include CUDA related libraries and tools to the Rust based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3

View File

@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt

View File

@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto COPY proto proto
COPY benchmark benchmark COPY benchmark benchmark
COPY router router COPY router router
COPY backends backends
COPY launcher launcher COPY launcher launcher
RUN cargo build --profile release-opt RUN cargo build --profile release-opt

View File

@ -5,13 +5,13 @@ install-server-cpu:
cd server && make install-server cd server && make install-server
install-router: install-router:
cd router && cargo install --path . cargo install --path backends/v3/
install-launcher: install-launcher:
cd launcher && cargo install --path . cargo install --path launcher/
install-benchmark: install-benchmark:
cd benchmark && cargo install --path . cargo install --path benchmark/
install: install-server install-router install-launcher install: install-server install-router install-launcher

View File

@ -0,0 +1,63 @@
cmake_minimum_required(VERSION 3.20)
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20)
include(FetchContent)
include(ExternalProject)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
#### External dependencies ####
include(cmake/fmt.cmake)
include(cmake/json.cmake)
include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake)
# Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
# TGI TRTLLM Backend definition
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
target_include_directories(tgi_trtllm_backend_impl PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
#### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
message(STATUS "Building tests")
FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2
GIT_TAG v3.6.0
)
FetchContent_MakeAvailable(Catch2)
# add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
# target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
include(CTest)
include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif ()

View File

@ -0,0 +1,26 @@
[package]
name = "text-generation-backends-trtllm"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "0.1"
async-stream = "0.3"
cxx = "1.0"
text-generation-router = { path = "../../router" }
tokenizers = { version = "0.19", features = ["hf-hub"] }
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15"
clap = { version = "4.5", features = ["derive"] }
thiserror = "1.0.62"
tracing = "0.1"
tracing-opentelemetry = "0.24"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
log = { version = "0.4", features = [] }
[build-dependencies]
cmake = "0.1"
cxx-build = { version = "1.0", features = ["parallel"] }
pkg-config = "0.3"

100
backends/trtllm/Dockerfile Normal file
View File

@ -0,0 +1,100 @@
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
ARG OMPI_VERSION="4.1.6"
# Build dependencies resolver stage
FROM lukemathwalker/cargo-chef:latest AS chef
WORKDIR /usr/src/text-generation-inference
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt install -y \
build-essential \
cmake \
curl \
gcc \
g++ \
git \
git-lfs \
libssl-dev \
ninja-build \
pkg-config \
python3 \
python3-setuptools \
tar \
wget
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
ARG OMPI_VERSION
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
mkdir /usr/src/mpi && \
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
cd /usr/src/mpi && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
make -j all && \
make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef
# Cache dependencies
COPY --from=planner /usr/src/text-generation-inference/recipe.json .
RUN cargo chef cook --release --recipe-path recipe.json
# Build actual TGI
ARG CUDA_ARCH_LIST
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
COPY . .
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
WORKDIR /usr/local/tgi/bin
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]

46
backends/trtllm/README.md Normal file
View File

@ -0,0 +1,46 @@
# Text Generation Inference - TensorRT-LLM Backend Implementation
## Description
This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API
## Simplified Request Sequence
```mermaid
sequenceDiagram
actor User
participant TextGenerationInference.HttpServer
participant TextGenerationInference.TensorRtLlmBackend
participant TextGenerationInference.TensorRtLlmWorkerThread
participant TensorRtLlm.Executor
participant Nvidia.Gpu
User ->> TextGenerationInference.HttpServer: POST /generate
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
activate Nvidia.Gpu
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
rect rgb(10, 92, 54)
loop every 100us
rect rgb(15, 81, 50)
alt Acquire lock to query executor
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
else There are new generated tokens
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
rect rgb(11, 110, 79)
alt Generated token is final
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
else Generated token is not final
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
end
end
end
end
deactivate Nvidia.Gpu
end
end
```

150
backends/trtllm/build.rs Normal file
View File

@ -0,0 +1,150 @@
use cxx_build::CFG;
use pkg_config;
use std::env;
use std::env::consts::ARCH;
use std::path::{absolute, PathBuf};
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.5";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
// Dependencies
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
("dylib", "tensorrt_llm"),
("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"),
("dylib", "decoder_attention"),
];
macro_rules! probe {
($name: expr, $version: expr) => {
if let Err(_) = pkg_config::probe_library($name) {
pkg_config::probe_library(&format!("{}-{}", $name, $version))
.expect(&format!("Failed to locate {}", $name));
}
};
}
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
let mut install_path = PathBuf::from(install_path);
if !install_path.is_absolute() {
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
}
let _ = cmake::Config::new(".")
.uses_cxx11()
.generator("Ninja")
.profile(match is_debug {
true => "Debug",
false => "Release",
})
.env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
.build();
// Additional transitive CMake dependencies
let deps_folder = out_dir.join("build").join("_deps");
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
let dep_name = match is_debug {
true => format!("{}d", dependency),
false => String::from(dependency),
};
let dep_path = deps_folder.join(format!("{}-build", dependency));
println!("cargo:rustc-link-search={}", dep_path.display());
println!("cargo:rustc-link-lib=static={}", dep_name);
}
// Emit linkage information from the artifacts we just built
let install_lib_path = install_path.join("lib");
println!(
r"cargo:warning=Adding link search path: {}",
install_lib_path.display()
);
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf) {
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
.include(deps_folder.join("fmt-src").join("include"))
.include(deps_folder.join("spdlog-src").join("include"))
.include(deps_folder.join("json-src").join("include"))
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
.include("/usr/local/cuda/include")
.include("/usr/local/tensorrt/include")
.file("src/ffi.cpp")
.std("c++20")
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
println!("cargo:rerun-if-changed=include/ffi.h");
println!("cargo:rerun-if-changed=src/ffi.cpp");
}
fn main() {
// Misc variables
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"),
_ => (false, "3"),
};
// Build the backend
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
// Build the FFI layer calling the backend above
build_ffi_layer(&deps_folder);
// Emit linkage search path
probe!("ompi", MPI_REQUIRED_VERSION);
// Probe CUDA & co. with pkg-config
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
probe!(name, CUDA_REQUIRED_VERSION);
});
// NCCL is slightly trickier because it might not have a pkgconfig installed
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
println!("cargo:rustc-link-lib=dylib=nccl");
// TensorRT
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
println!("cargo:rustc-link-lib=dylib=nvinfer");
// TensorRT-LLM
TENSORRT_LLM_TRANSITIVE_DEPS
.iter()
.for_each(|(link_type, name)| {
println!("cargo:rustc-link-lib={}={}", link_type, name);
});
// Backend
BACKEND_DEPS.iter().for_each(|name| {
println!("cargo:rustc-link-lib=static={}", name);
});
}

View File

@ -0,0 +1,6 @@
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt
GIT_TAG 11.0.1
)
FetchContent_MakeAvailable(fmt)

View File

@ -0,0 +1,5 @@
fetchcontent_declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
)
fetchcontent_makeavailable(json)

View File

@ -0,0 +1,17 @@
set(SPDLOG_USE_FMT ON)
set(SPDLOG_BUILD_SHARED OFF)
set(SPDLOG_FMT_EXTERNAL ON)
# Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
else ()
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
endif ()
fetchcontent_declare(
spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v1.14.1
)
fetchcontent_makeavailable(spdlog)

View File

@ -0,0 +1,42 @@
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
set(USE_CXX11_ABI ON)
set(BUILD_PYT OFF)
set(BUILD_PYBIND OFF)
set(BUILD_MICRO_BENCHMARKS OFF)
set(BUILD_BENCHMARKS OFF)
set(BUILD_TESTS OFF)
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON)
set(NVTX_DISABLE OFF)
else ()
set(FAST_BUILD OFF)
set(FAST_MATH ON)
set(NVTX_DISABLE ON)
endif ()
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
GIT_SHALLOW FALSE
)
fetchcontent_makeavailable(trtllm)
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
CACHE INTERNAL "nvrtc wrapper library path")
# The same Executor Static library
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")

View File

@ -0,0 +1,121 @@
//
// Created by Morgan Funtowicz on 6/30/24.
//
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H
#include <cmath>
#include <filesystem>
#include <span>
#include <vector>
#include <nlohmann/json.hpp>
#include <tensorrt_llm/runtime/common.h>
#include <tensorrt_llm/executor/executor.h>
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
using json = nlohmann::json;
namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
using RequestId = tle::IdType;
using TokenId = tle::TokenIdType;
/**
* Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine
*/
void InitializeBackend();
/**
*
* @param config TensorRT-LLM configuration object
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
* @return
*/
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
/**
* Get the sampling configuration from the parameters provided by TGI
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return
*/
tle::SamplingConfig GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
);
/**
*
*/
class TensorRtLlmBackend {
private:
const json config;
tle::Executor executor;
public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker
);
/**
* Indicate if the backend is ready to accept incoming request
* @return true if ready, false otherwise
*/
[[nodiscard]] bool IsReady() const;
/**
* Query the executor for the number of token available for pulling
* @return
*/
[[nodiscard]] size_t NumResponsesReady() const;
/**
* Submit a new generation task to the executor
* @param tokens
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return Request id related to this generation for reference
*/
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
);
/**
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector<tle::Response> Poll(RequestId requestId);
/**
* Stop the underlying executor
*/
void Shutdown();
};
}
#endif //TGI_TRTLLM_BACKEND_H

View File

@ -0,0 +1,75 @@
//
// Created by mfuntowicz on 7/11/24.
//
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
#include <cstddef>
#include "backend.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl;
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends {
// struct GenerationContext;
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
*
* @param engineFolder
* @param executorWorker
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @return
*/
bool IsReady() const;
/***
*
* @param tokens
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/***
*
* @param requestId
* @param ctx
* @param callback
* @return
*/
size_t StreamTokens(
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
};
/***
*
* @param engineFolder
* @return
*/
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
}
#endif //TGI_TRTLLM_BACKEND_FFI_H

View File

@ -0,0 +1,59 @@
//
// Created by mfuntowicz on 7/23/24.
//
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
#define TGI_TRTLLM_BACKEND_HARDWARE_H
#include <cstdint>
#include <limits>
#include <fmt/base.h>
#include <spdlog/spdlog.h>
#include <nvml.h>
namespace huggingface::hardware::cuda {
#define AMPERE_SM_MAJOR 8
#define HOPPER_SM_MAJOR 8
/**
* Store information about the version of the CUDA Compute Capabilities detected on the device
*/
struct CudaComputeCapabilities {
int32_t major;
int32_t minor;
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
};
CudaComputeCapabilities GetCudaComputeCapabilities() {
// Get the compute capabilities of the current hardware
nvmlDevice_t device;
CudaComputeCapabilities capabilities{0, 0};
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
}
}
return capabilities;
}
/**
* Return the number of GPU detected. If no GPU is detected, return size_t::max()
* @return
*/
std::optional<size_t> GetNumDevices() {
uint32_t numGpus = 0;
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
return std::optional(numGpus);
} else {
return std::nullopt;
}
}
}
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H

View File

@ -0,0 +1,146 @@
#include <fstream>
#include <fmt/ranges.h>
#include <spdlog/spdlog.h>
#include <nvml.h>
#include "backend.h"
#include "hardware.h"
void huggingface::tgi::backends::InitializeBackend() {
SPDLOG_INFO("Initializing Backend...");
nvmlInit_v2();
initTrtLlmPlugins();
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
if (numGpus.has_value()) {
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
} else {
SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
}
}
[[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1);
// Retrieve the compute capabilities to enable some options at runtime
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
SPDLOG_INFO("Detected single engine deployment, using leader mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
tle::CommunicationMode::kLEADER,
std::nullopt,
std::nullopt,
std::nullopt
));
} else { // Multiple engines -> using orchestrator mode (MPI involved)
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
tle::CommunicationMode::kORCHESTRATOR,
std::nullopt,
std::nullopt,
tle::OrchestratorConfig(true, workerPath, nullptr, true)
));
}
// Define some configuration variables
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
return execConfig;
}
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) {
return tle::SamplingConfig(
1, // TGI only use a single beam
topK,
topP,
std::nullopt,
std::nullopt,
std::nullopt,
seed,
temperature,
temperature,
std::nullopt,
repetition_penalty,
std::nullopt,
frequency_penalty
);
}
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &enginesFolder,
const std::filesystem::path &executorWorker
) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor(
enginesFolder,
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string()
)) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
}
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
return executor.canEnqueueRequests();
}
[[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
return executor.getNumResponsesReady();
}
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed
) {
#ifdef NDEBUG
SPDLOG_DEBUG(
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
tokens.size(),
executor.getLatestIterationStats().back().numActiveRequests
);
#else
SPDLOG_DEBUG(
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
fmt::join(tokens, ", "),
executor.getLatestIterationStats().front().numActiveRequests
);
#endif
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto output = tle::OutputConfig(true, false, false, true, false);
return executor.enqueueRequest(
tle::Request{tokens, maxNewTokens, true, sampling, output});
}
[[nodiscard("Generated tokens result must be used")]]
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
return executor.awaitResponses(requestId);
}
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
SPDLOG_INFO("Shutting down executor");
executor.shutdown();
}

View File

@ -0,0 +1,111 @@
#!/bin/bash
set -ex
TRT_VER="10.2.0.19"
CUDA_VER="12.5"
CUDNN_VER="9.2.1.18-1"
NCCL_VER="2.22.3-1+cuda12.5"
CUBLAS_VER="12.5.3.2-1"
NVRTC_VER="12.5.82-1"
for i in "$@"; do
case $i in
--TRT_VER=?*) TRT_VER="${i#*=}";;
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
*) ;;
esac
shift
done
NVCC_VERSION_OUTPUT=$(nvcc --version)
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
exit 1
fi
install_ubuntu_requirements() {
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
ARCH=$(uname -m)
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
dpkg -i cuda-keyring_1.0-1_all.deb
apt-get update
if [[ $(apt list --installed | grep libcudnn9) ]]; then
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
fi
if [[ $(apt list --installed | grep libnccl) ]]; then
apt-get remove --purge -y --allow-change-held-packages libnccl*
fi
if [[ $(apt list --installed | grep libcublas) ]]; then
apt-get remove --purge -y --allow-change-held-packages libcublas*
fi
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
fi
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
# NVRTC static library doesn't exist in NGC PyTorch container.
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
apt-get clean
rm -rf /var/lib/apt/lists/*
}
install_centos_requirements() {
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
yum -y update
yum -y install epel-release
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
yum clean all
}
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.5"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
fi
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar
}
# Install base packages depending on the base OS
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
case "$ID" in
debian)
install_ubuntu_requirements
install_tensorrt
;;
ubuntu)
install_ubuntu_requirements
install_tensorrt
;;
centos)
install_centos_requirements
install_tensorrt
;;
*)
echo "Unable to determine OS..."
exit 1
;;
esac

View File

@ -0,0 +1,329 @@
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
type InferResult<T> = Result<T, InferError>;
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that tokens the to end stream.
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
queued: Instant,
start: Option<Instant>,
}
impl Stream for Generation {
type Item = usize;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let interval = POLLING_INTERVAL_US.get_or_init(|| {
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
});
if !self.done.load(Ordering::Relaxed) {
let backend = pin!(self.executor.read());
let status = match backend.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
Poll::Pending
} else {
Poll::Ready(Some(ready))
}
}
Poll::Pending => Poll::Pending,
};
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_micros(*interval)).await;
waker.wake();
});
status
} else {
Poll::Ready(None) // end of stream
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
tokenizer: Arc<Tokenizer>,
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
// the number of available tokens (read only) in the Generation stream
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
tokenizer: Arc::new(tokenizer),
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
executor_worker_path.as_ref().to_str().unwrap(),
))),
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(text),
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
},
}
}
fn generate(
&self,
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) {
let tokenizer = Arc::clone(&self.tokenizer);
let executor = Arc::clone(&self.backend);
// Let's push this in async context
tokio::spawn(async move {
// Define the generation state
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
// Define the context over the generation
// TODO(asap): Do we really need so many shared-ownership?
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer,
tokens: vec![],
done: Arc::clone(&generation.done),
start: None,
queued: Instant::now(),
});
// We are leaking the context on-purpose to avoid the box being dropped while there are
// still computation ongoing
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
let ctx_ = Box::leak(ctx);
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "submit")
.in_scope(|| async {
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&tokens,
top_k as i32,
top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
request_id
})
.await;
while let Some(_) = generation.next().await {
let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
.tokenizer
.get_added_vocabulary()
.is_special_token(&text);
// Create the structure holding the token
let token = Token {
id: step.token_id,
text,
logprob: step.log_prob,
special,
};
if step.is_final {
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued,
})
} else {
Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
})
}
} else {
error!("Error caught while decoding: {}", &step.error_msg);
Err(InferError::GenerationError(step.error_msg))
};
// Send the parcel to the client
inner_ctx
.sender
.send(parcel)
.expect("Failed to sent msg through the channel");
},
);
}
})
.await;
}
// "Properly" free the shared context...
// TODO: clean that piece of sh** asap
unsafe {
let _ = Box::from_raw(ctx_);
}
});
}
}
#[async_trait]
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
// Let's add a few more validation
let input = TensorRtLlmBackend::validate(&request)?;
// Channel to stream the generated token as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = &request.parameters;
// Preprocess the inputs to send to TRTLLM backend
let encoding = self
.tokenizer
.encode(input.as_str(), true)
.map_err(|e| InferError::GenerationError(e.to_string()))?;
// Generate the response
self.generate(
sender,
Vec::from(encoding.get_ids()),
params.top_k,
params.top_p,
params.temperature,
params.repetition_penalty,
params.frequency_penalty,
params.seed,
);
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}

View File

@ -0,0 +1,15 @@
use thiserror::Error;
use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
#[error("Tokenizer error: {0}")]
Tokenizer(String),
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

View File

@ -0,0 +1,84 @@
//
// Created by mfuntowicz on 6/30/24.
//
#pragma once
#include <cmath>
#include <exception>
#include <filesystem>
#include <limits>
#include <iterator>
#include <vector>
#include <spdlog/spdlog.h>
#include "backends/trtllm/include/ffi.h"
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
const std::string_view &engineFolder,
const std::string_view &executorWorker
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
return TensorRtLlmBackend::IsReady();
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
float_t frequency_penalty, uint64_t seed) {
// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
return TensorRtLlmBackend::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
const uint64_t requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
size_t numTokens = 0;
for (const auto &item: Poll(requestId)) {
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
const auto logProb = decoded.logProbs.value()[0][0];
++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
};
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
const auto what = item.getErrorMsg();
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
step = huggingface::tgi::backends::GenerationStep{
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
};
}
callback(std::move(ctx), std::move(step));
}
return numTokens;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
}

View File

@ -0,0 +1,78 @@
pub use backend::{GenerationContext, TensorRtLlmBackend};
mod backend;
pub mod errors;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
pub struct GenerationStep {
token_id: u32,
log_prob: f32,
is_final: bool,
has_error: bool,
error_msg: String,
}
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp");
/// Represent an instance of the underlying TensorRT-LLM backend
type TensorRtLlmBackendImpl;
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
///
/// # Arguments
///
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
/// * `executor_worker`: Path to the TRTLLM executor worker
///
/// returns: <unknown>
///
/// # Examples
///
/// ```
///
/// ```
#[rust_name = "create_tensorrt_llm_backend"]
fn CreateTensorRtLlmBackend(
engine_folder: &str,
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
#[rust_name = "submit"]
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
top_k: i32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> u64;
#[rust_name = "stream_tokens"]
unsafe fn StreamTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64,
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
}
}

166
backends/trtllm/src/main.rs Normal file
View File

@ -0,0 +1,166 @@
use std::collections::HashMap;
use std::path::PathBuf;
use clap::Parser;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(long, env, required = true)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
model_id: String,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
}
#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
max_batch_total_tokens,
hostname,
port,
tokenizer_name,
tokenizer_config_path,
revision,
model_id,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token,
executor_worker,
} = args;
// Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(TensorRtLlmBackendError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if !executor_worker.exists() {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
"`executor_work` specified path doesn't exists: {}",
executor_worker.display()
)));
}
// Run server
let tokenizer = Tokenizer::from_pretrained(
tokenizer_name.clone(),
Some(FromPretrainedParameters {
revision: revision.clone().unwrap_or(String::from("main")),
user_agent: HashMap::new(),
auth_token,
}),
)
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
None,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
false,
None,
None,
messages_api_enabled,
true,
max_client_batch_size,
)
.await?;
Ok(())
}

View File

@ -0,0 +1,14 @@
//
// Created by mfuntowicz on 7/2/24.
//
#include <catch2/catch_all.hpp>
#include <spdlog/spdlog.h>
#include "../include/backend.h"
TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
spdlog::info("Loading config from: {}", absolute(engines).string());
huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
}

66
backends/v3/Cargo.toml Normal file
View File

@ -0,0 +1,66 @@
[package]
name = "text-generation-router-v3"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]

19
backends/v3/build.rs Normal file
View File

@ -0,0 +1,19 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/client/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/client/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

508
backends/v3/src/backend.rs Normal file
View File

@ -0,0 +1,508 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub struct BackendV3 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
impl BackendV3 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client.clone(),
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
));
Self {
queue,
batching_task_notifier,
client,
}
}
}
#[async_trait]
impl Backend for BackendV3 {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the queue
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok(UnboundedReceiverStream::new(response_rx))
}
async fn health(&self, current_health: bool) -> bool {
if current_health {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok()
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self {
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
}
}

View File

@ -0,0 +1,284 @@
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v3 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
inputs,
input_chunks: Some(Input {
chunks: input_chunks,
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -0,0 +1,76 @@
//! Text Generation gRPC client library
use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod grpc_client;
mod sharded_client;
pub use grpc_client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
#[error("Server error: {0}")]
Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -0,0 +1,260 @@
use crate::client::{ClientError, Result};
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
Ok(())
}
}

142
backends/v3/src/lib.rs Normal file
View File

@ -0,0 +1,142 @@
mod backend;
mod block_allocator;
mod client;
mod queue;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
#[schema(example = "1")]
pub speculate: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V3Error::Connection)?;
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V3Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
};
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
);
tracing::info!("Using backend V3");
Ok((backend, backend_info))
}
#[derive(Debug, Error)]
pub enum V3Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}

204
backends/v3/src/main.rs Normal file
View File

@ -0,0 +1,204 @@
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
command,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
api_key,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
} = args;
if let Some(Commands::PrintSchema) = command {
use utoipa::OpenApi;
let api_doc = text_generation_router::server::ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
};
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
let (backend, _backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
)
.await?;
// Run server
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
)
.await?;
Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Backend failed: {0}")]
Backend(#[from] V3Error),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

View File

@ -1,17 +1,17 @@
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError; use crate::client;
use crate::infer::InferStreamResponse; use crate::client::{
use crate::validation::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
}; };
use nohash_hasher::{BuildNoHashHasher, IntMap}; use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min}; use std::cmp::{max, min};
use std::collections::VecDeque; use std::collections::VecDeque;
use text_generation_client::v3::{ use text_generation_router::infer::InferError;
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
}; };
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
@ -337,8 +337,22 @@ impl State {
batch_requests.push(Request { batch_requests.push(Request {
id, id,
prefill_logprobs: entry.request.decoder_input_details, prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(Input { input_chunks: Some(client::Input {
chunks: entry.request.inputs.clone(), chunks: entry
.request
.inputs
.clone()
.into_iter()
.map(|c| client::InputChunk {
chunk: Some(match c {
Chunk::Text(text) => client::Chunk::Text(text),
Chunk::Image(image) => client::Chunk::Image(client::Image {
data: image.data,
mimetype: image.mimetype,
}),
}),
})
.collect(),
}), }),
inputs: entry.request.inputs.chunks_to_string(), inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate, truncate: entry.request.truncate,

View File

@ -21,7 +21,7 @@ float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]} serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0" serde_json = "1.0"
tabled = "0.14.0" tabled = "0.14.0"
text-generation-client = { path = "../router/client" } text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true } tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }

View File

@ -1580,16 +1580,11 @@
"type": "object", "type": "object",
"required": [ "required": [
"model_id", "model_id",
"model_dtype",
"model_device_type",
"max_concurrent_requests", "max_concurrent_requests",
"max_best_of", "max_best_of",
"max_stop_sequences", "max_stop_sequences",
"max_input_tokens", "max_input_tokens",
"max_total_tokens", "max_total_tokens",
"waiting_served_ratio",
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers", "validation_workers",
"max_client_batch_size", "max_client_batch_size",
"router", "router",
@ -1601,18 +1596,6 @@
"example": "null", "example": "null",
"nullable": true "nullable": true
}, },
"max_batch_size": {
"type": "integer",
"example": "null",
"nullable": true,
"minimum": 0
},
"max_batch_total_tokens": {
"type": "integer",
"format": "int32",
"example": "32000",
"minimum": 0
},
"max_best_of": { "max_best_of": {
"type": "integer", "type": "integer",
"example": "2", "example": "2",
@ -1644,19 +1627,6 @@
"example": "2048", "example": "2048",
"minimum": 0 "minimum": 0
}, },
"max_waiting_tokens": {
"type": "integer",
"example": "20",
"minimum": 0
},
"model_device_type": {
"type": "string",
"example": "cuda"
},
"model_dtype": {
"type": "string",
"example": "torch.float16"
},
"model_id": { "model_id": {
"type": "string", "type": "string",
"description": "Model info", "description": "Model info",
@ -1690,11 +1660,6 @@
"version": { "version": {
"type": "string", "type": "string",
"example": "0.5.0" "example": "0.5.0"
},
"waiting_served_ratio": {
"type": "number",
"format": "float",
"example": "1.2"
} }
} }
}, },

View File

@ -431,20 +431,18 @@ Options:
[env: LORA_ADAPTERS=] [env: LORA_ADAPTERS=]
``` ```
## DISABLE_USAGE_STATS ## USAGE_STATS
```shell ```shell
--disable-usage-stats --usage-stats <USAGE_STATS>
Disable sending of all usage statistics Control if anonymous usage stats are collected. Options are "on", "off" and "no-stack" Defaul is on
[env: DISABLE_USAGE_STATS=] [env: USAGE_STATS=]
[default: on]
``` Possible values:
## DISABLE_CRASH_REPORTS - on: Default option, usage statistics are collected anonymously
```shell - off: Disables all collection of usage statistics
--disable-crash-reports - no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
Disable sending of crash reports, but allow anonymous usage statistics
[env: DISABLE_CRASH_REPORTS=]
``` ```
## HELP ## HELP

View File

@ -36,6 +36,18 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
``` ```
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
```bash
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
```
note it's possible to mix adapter_ids with adapter_id=adapter_path e.g.
```bash
LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/
```
In the server logs, you will see the following message: In the server logs, you will see the following message:
```txt ```txt

View File

@ -70,4 +70,6 @@ As of release 2.1.2 this is an example of the data collected:
## How to opt-out ## How to opt-out
You can easily opt out by passing the `--disable-usage-stats` to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports` which disables sending specific crash reports, but allows anonymous usage statistics. By passing the `--usage-stats` to the text-generation-launcher you can control how much usage statistics are being collected.
`--usage-stats=no-stack` will not emit the stack traces from errors and the error types, but will continue to send start and stop events
`--usage-stats=off` will completely disable everything

View File

@ -168,6 +168,33 @@ impl std::fmt::Display for RopeScaling {
} }
} }
#[derive(Clone, Copy, Debug, ValueEnum)]
pub enum UsageStatsLevel {
/// Default option, usage statistics are collected anonymously
On,
/// Disables all collection of usage statistics
Off,
/// Doesn't send the error stack trace or error type, but allows sending a crash event
NoStack,
}
impl std::fmt::Display for UsageStatsLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in track with `server`.
match self {
UsageStatsLevel::On => {
write!(f, "on")
}
UsageStatsLevel::Off => {
write!(f, "off")
}
UsageStatsLevel::NoStack => {
write!(f, "no-stack")
}
}
}
}
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
@ -466,13 +493,11 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
lora_adapters: Option<String>, lora_adapters: Option<String>,
/// Disable sending of all usage statistics /// Control if anonymous usage stats are collected.
#[clap(default_value = "false", long, env)] /// Options are "on", "off" and "no-stack"
disable_usage_stats: bool, /// Defaul is on.
#[clap(default_value = "on", long, env)]
/// Disable sending of crash reports, but allow anonymous usage statistics usage_stats: UsageStatsLevel,
#[clap(default_value = "false", long, env)]
disable_crash_reports: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@ -1218,12 +1243,8 @@ fn spawn_webserver(
]; ];
// Pass usage stats flags to router // Pass usage stats flags to router
if args.disable_usage_stats { router_args.push("--usage-stats".to_string());
router_args.push("--disable-usage-stats".to_string()); router_args.push(args.usage_stats.to_string());
}
if args.disable_crash_reports {
router_args.push("--disable-crash-reports".to_string());
}
// Grammar support // Grammar support
if args.disable_grammar_support { if args.disable_grammar_support {

View File

@ -7,25 +7,18 @@ edition.workspace = true
authors.workspace = true authors.workspace = true
homepage.workspace = true homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies] [dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5" async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] } axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16" axum-tracing-opentelemetry = "0.16"
text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] } clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28" futures = "0.3.28"
hf-hub = { workspace = true } hf-hub = { workspace = true }
itertools = "0.10" itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] } jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.23.0" metrics = { workspace = true }
metrics-exporter-prometheus = { version = "0.15.1", features = [] } metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0" opentelemetry-otlp = "0.13.0"
@ -55,6 +48,7 @@ base64 = { workspace = true }
sysinfo = "0.30.13" sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
csv = "1.3.0" csv = "1.3.0"
ureq = "=2.9"
[build-dependencies] [build-dependencies]

View File

@ -1 +0,0 @@
*

View File

@ -1 +0,0 @@
*

View File

@ -1,528 +1,85 @@
/// Batching and inference logic use crate::infer::InferError;
use crate::infer::v3::queue::{Entry, Queue}; use crate::{
use crate::infer::{ ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
}; };
use crate::validation::ValidGenerateRequest; use minijinja::{Environment, ErrorKind, Template};
use crate::{FinishReason, PrefillToken, Token}; use minijinja_contrib::pycompat;
use nohash_hasher::IntMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct SchedulerV3 { /// Raise a exception (custom function) used in the chat templates
/// Request queue pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
queue: Queue, Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
} }
impl SchedulerV3 { #[derive(Clone)]
#[allow(clippy::too_many_arguments)] pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
pub(crate) fn new( pub(crate) fn new(
client: ShardedClient, template: String,
waiting_served_ratio: f32, bos_token: Option<TokenizerConfigToken>,
max_batch_prefill_tokens: u32, eos_token: Option<TokenizerConfigToken>,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { let mut env = Box::new(Environment::new());
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") // enable things like .strip() or .capitalize()
} else { env.set_unknown_method_callback(pycompat::unknown_method_callback);
false let template_str = template.into_boxed_str();
}; env.add_function("raise_exception", raise_exception);
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic // check if contains the tools variable within the template
tokio::spawn(batching_task( let use_default_tool_template =
client, !template_str.as_ref().replace(' ', "").contains("{{tools}}");
waiting_served_ratio, // leaking env and template_str as read-only, static resources for performance.
max_batch_prefill_tokens, let template = Box::leak(env)
max_batch_total_tokens, .template_from_str(Box::leak(template_str))
max_waiting_tokens, .unwrap();
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
generation_health,
));
Self { Self {
queue, template,
batching_task_notifier, bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
} }
} }
}
impl Scheduler for SchedulerV3 { pub(crate) fn apply(
#[instrument(skip_all)]
fn schedule(
&self, &self,
request: ValidGenerateRequest, mut messages: Vec<Message>,
permit: OwnedSemaphorePermit, grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<GenerateStreamResponse, InferError> { ) -> Result<String, InferError> {
// MPSC channel to communicate with the background batching task if self.use_default_tool_template {
let (response_tx, response_rx) = mpsc::unbounded_channel(); if let Some(last_message) = messages.last_mut() {
let input_length = request.input_length; if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text {
// Append the request to the queue text: format!("\n---\n{}\n{}", tool_prompt, tools),
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok((
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
generation_health: Arc<AtomicBool>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
}); });
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
} }
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
// Update health
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
generation_health.store(false, Ordering::SeqCst);
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
} }
} }
}
Ok(stopped) let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
}
/// Send errors to Infer for all `entries` self.template
#[instrument(skip_all)] .render(ChatTemplateInputs {
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) { messages,
entries.drain().for_each(|(_, entry)| { bos_token: self.bos_token.as_deref(),
// Create and enter a span to link this function back to the entry eos_token: self.eos_token.as_deref(),
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); add_generation_prompt: true,
let err = InferError::GenerationError(error.to_string()); tools: None,
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tools_prompt: None,
tracing::error!("{err}"); })
.map_err(InferError::TemplateError)
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
let v3_finish_reason =
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
} }
} }
// tests // tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage}; use crate::{ChatTemplateInputs, TextMessage};
use minijinja::Environment; use minijinja::Environment;

View File

@ -1,34 +0,0 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::Health;
#[derive(Clone)]
pub(crate) struct HealthCheck {
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
}
impl HealthCheck {
pub(crate) fn new(
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
let value = if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}

View File

@ -1,23 +1,18 @@
mod health; // pub(crate) mod v2;
pub(crate) mod v2; mod chat_template;
pub(crate) mod v3; pub mod tool_grammar;
pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::GrammarType;
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, Message, PrefillToken, Token,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
}; };
use async_trait::async_trait;
use chat_template::ChatTemplate;
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::ErrorKind;
use minijinja_contrib::pycompat; use std::sync::atomic::{AtomicBool, Ordering};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::instrument;
pub(crate) trait Scheduler { #[async_trait]
pub trait Backend {
fn schedule( fn schedule(
&self, &self,
request: ValidGenerateRequest, request: ValidGenerateRequest,
permit: OwnedSemaphorePermit, ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
) -> Result<GenerateStreamResponse, InferError>;
async fn health(&self, current_health: bool) -> bool;
} }
/// Inference struct /// Inference struct
@ -39,18 +36,20 @@ pub(crate) trait Scheduler {
pub struct Infer { pub struct Infer {
/// Validation /// Validation
validation: Validation, validation: Validation,
/// Request scheduler /// Request backend
scheduler: Arc<dyn Scheduler + Send + Sync>, backend: Arc<dyn Backend + Send + Sync>,
/// Chat template /// Chat template
chat_template: Option<ChatTemplate>, chat_template: Option<ChatTemplate>,
/// Inference limit /// Inference limit
limit_concurrent_requests: Arc<Semaphore>, limit_concurrent_requests: Arc<Semaphore>,
/// Backend health
backend_health: Arc<AtomicBool>,
} }
impl Infer { impl Infer {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
scheduler: Arc<dyn Scheduler + Send + Sync>, backend: impl Backend + Send + Sync + 'static,
validation: Validation, validation: Validation,
max_concurrent_requests: usize, max_concurrent_requests: usize,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
@ -71,18 +70,22 @@ impl Infer {
// Inference limit with a semaphore // Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
// Backend health
let backend_health = Arc::new(AtomicBool::new(false));
Self { Self {
validation, validation,
scheduler, backend: Arc::new(backend),
chat_template, chat_template,
limit_concurrent_requests: semaphore, limit_concurrent_requests: semaphore,
backend_health,
} }
} }
/// Add a new request to the queue and return a stream of InferStreamResponse /// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) async fn generate_stream( pub(crate) async fn generate_stream<'a>(
&self, &'a self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> { ) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore // Limit concurrent requests by acquiring a permit from the semaphore
@ -103,7 +106,10 @@ impl Infer {
err err
})?; })?;
self.scheduler.schedule(valid_request, permit) let input_length = valid_request.input_length;
let generation_stream = self.backend.schedule(valid_request)?;
Ok((permit, input_length, generation_stream))
} }
/// Tokenizer the input /// Tokenizer the input
@ -155,7 +161,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives // Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; let (_permit, _input_length, stream) = self.generate_stream(request).await?;
// Return values // Return values
let mut result_prefill = Vec::new(); let mut result_prefill = Vec::new();
@ -165,6 +171,8 @@ impl Infer {
let mut result_start = None; let mut result_start = None;
let mut result_queued = None; let mut result_queued = None;
let mut stream = Box::pin(stream);
// Iterate on stream // Iterate on stream
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
match response? { match response? {
@ -256,207 +264,15 @@ impl Infer {
let best_response = infer_responses.remove(max_index); let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses)) Ok((best_response, infer_responses))
} }
}
/// Raise a exception (custom function) used in the chat templates #[instrument(skip(self))]
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> { pub(crate) async fn health(&self) -> bool {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) let health = self
} .backend
.health(self.backend_health.load(Ordering::SeqCst))
#[derive(Clone)] .await;
struct ChatTemplate { self.backend_health.store(health, Ordering::SeqCst);
template: Template<'static, 'static>, health
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
});
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
pub struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
} }
} }
@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = (
); );
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct GeneratedText { pub struct GeneratedText {
pub(crate) text: String, pub text: String,
pub(crate) generated_tokens: u32, pub generated_tokens: u32,
pub(crate) finish_reason: FinishReason, pub finish_reason: FinishReason,
pub(crate) seed: Option<u64>, pub seed: Option<u64>,
} }
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum InferStreamResponse { pub enum InferStreamResponse {
// Optional first message // Optional first message
Prefill(Vec<PrefillToken>), Prefill(Vec<PrefillToken>),
// Intermediate messages // Intermediate messages

View File

@ -0,0 +1,135 @@
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
pub(crate) struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
}
}

View File

@ -1,4 +1,4 @@
mod queue; mod queue;
mod scheduler; mod scheduler;
pub(crate) use scheduler::SchedulerV2; pub(crate) use scheduler::BackendV2;

View File

@ -1,7 +1,7 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
}; };
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token}; use crate::{FinishReason, PrefillToken, Token};
@ -18,14 +18,14 @@ use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct SchedulerV2 { pub(crate) struct BackendV2 {
/// Request queue /// Request queue
queue: Queue, queue: Queue,
/// Notify batcher on queue appends /// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>, batching_task_notifier: Arc<Notify>,
} }
impl SchedulerV2 { impl BackendV2 {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
client: ShardedClient, client: ShardedClient,
@ -69,7 +69,7 @@ impl SchedulerV2 {
} }
} }
impl Scheduler for SchedulerV2 { impl Backend for BackendV2 {
#[instrument(skip_all)] #[instrument(skip_all)]
fn schedule( fn schedule(
&self, &self,

View File

@ -1,5 +0,0 @@
mod block_allocator;
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV3;

View File

@ -1,11 +1,12 @@
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
pub mod config; pub mod config;
mod infer; pub mod infer;
pub mod server; pub mod server;
mod validation; pub mod validation;
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
mod kserve; mod kserve;
pub mod logging;
pub mod usage_stats; pub mod usage_stats;
@ -148,12 +149,13 @@ pub struct Info {
pub model_id: String, pub model_id: String,
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")] #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
pub model_sha: Option<String>, pub model_sha: Option<String>,
#[schema(example = "torch.float16")] // #[schema(example = "torch.float16")]
pub model_dtype: String, // pub model_dtype: String,
#[schema(example = "cuda")] // #[schema(example = "cuda")]
pub model_device_type: String, // pub model_device_type: String,
#[schema(nullable = true, example = "text-generation")] #[schema(nullable = true, example = "text-generation")]
pub model_pipeline_tag: Option<String>, pub model_pipeline_tag: Option<String>,
/// Router Parameters /// Router Parameters
#[schema(example = "128")] #[schema(example = "128")]
pub max_concurrent_requests: usize, pub max_concurrent_requests: usize,
@ -165,18 +167,11 @@ pub struct Info {
pub max_input_tokens: usize, pub max_input_tokens: usize,
#[schema(example = "2048")] #[schema(example = "2048")]
pub max_total_tokens: usize, pub max_total_tokens: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "2")] #[schema(example = "2")]
pub validation_workers: usize, pub validation_workers: usize,
#[schema(example = "32")] #[schema(example = "32")]
pub max_client_batch_size: usize, pub max_client_batch_size: usize,
/// Router Info /// Router Info
#[schema(example = "text-generation-router")] #[schema(example = "text-generation-router")]
pub router: &'static str, pub router: &'static str,
@ -624,7 +619,7 @@ impl ChatCompletion {
message, message,
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.format(true),
}], }],
usage: Usage { usage: Usage {
prompt_tokens: details.prefill.len() as u32, prompt_tokens: details.prefill.len() as u32,
@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken { pub struct PrefillToken {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, pub id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, pub text: String,
#[schema(nullable = true, example = - 0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, pub logprob: f32,
} }
#[derive(Debug, Serialize, ToSchema, Clone)] #[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token { pub struct Token {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, pub id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, pub text: String,
#[schema(nullable = true, example = - 0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, pub logprob: f32,
#[schema(example = "false")] #[schema(example = "false")]
special: bool, pub special: bool,
} }
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
@ -1102,7 +1097,7 @@ pub struct SimpleToken {
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))] #[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")] #[schema(example = "Length")]
pub(crate) enum FinishReason { pub enum FinishReason {
#[schema(rename = "length")] #[schema(rename = "length")]
Length, Length,
#[serde(rename = "eos_token")] #[serde(rename = "eos_token")]
@ -1122,6 +1117,15 @@ impl std::fmt::Display for FinishReason {
} }
} }
impl FinishReason {
pub fn format(&self, use_stop: bool) -> String {
match self {
FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(),
_ => self.to_string(),
}
}
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence { pub(crate) struct BestOfSequence {
#[schema(example = "test")] #[schema(example = "test")]
@ -1162,6 +1166,12 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>, pub details: Option<Details>,
} }
#[derive(Serialize, ToSchema)]
pub(crate) struct ChatTokenizeResponse {
pub(crate) tokenize_response: TokenizeResponse,
pub(crate) templated_text: String,
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
#[serde(transparent)] #[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>); pub(crate) struct TokenizeResponse(Vec<SimpleToken>);

81
router/src/logging.rs Normal file
View File

@ -0,0 +1,81 @@
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,5 @@
use crate::config::Config; use crate::config::Config;
use clap::ValueEnum;
use csv::ReaderBuilder; use csv::ReaderBuilder;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use serde::Serialize; use serde::Serialize;
@ -13,6 +14,13 @@ use uuid::Uuid;
const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi";
#[derive(Copy, Clone, Debug, Serialize, ValueEnum)]
pub enum UsageStatsLevel {
On,
NoStack,
Off,
}
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct UserAgent { pub struct UserAgent {
pub uid: String, pub uid: String,
@ -71,72 +79,69 @@ impl UsageStatsEvent {
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct Args { pub struct Args {
model_config: Option<Config>, model_config: Option<Config>,
tokenizer_config: Option<String>, tokenizer_class: Option<String>,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, // waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, // max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>, // max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, // max_waiting_tokens: usize,
max_batch_size: Option<usize>, // max_batch_size: Option<usize>,
revision: Option<String>, revision: Option<String>,
validation_workers: usize, validation_workers: usize,
messages_api_enabled: bool, messages_api_enabled: bool,
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
disable_usage_stats: bool, usage_stats_level: UsageStatsLevel,
disable_crash_reports: bool,
} }
impl Args { impl Args {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
model_config: Option<Config>, model_config: Option<Config>,
tokenizer_config: Option<String>, tokenizer_class: Option<String>,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, // waiting_served_ratio: f32,
max_batch_prefill_tokens: u32, // max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>, // max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize, // max_waiting_tokens: usize,
max_batch_size: Option<usize>, // max_batch_size: Option<usize>,
revision: Option<String>, revision: Option<String>,
validation_workers: usize, validation_workers: usize,
messages_api_enabled: bool, messages_api_enabled: bool,
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
disable_usage_stats: bool, usage_stats_level: UsageStatsLevel,
disable_crash_reports: bool,
) -> Self { ) -> Self {
Self { Self {
model_config, model_config,
tokenizer_config, tokenizer_class,
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, // waiting_served_ratio,
max_batch_prefill_tokens, // max_batch_prefill_tokens,
max_batch_total_tokens, // max_batch_total_tokens,
max_waiting_tokens, // max_waiting_tokens,
max_batch_size, // max_batch_size,
revision, revision,
validation_workers, validation_workers,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
disable_usage_stats, usage_stats_level,
disable_crash_reports,
} }
} }
} }

View File

@ -5,13 +5,12 @@ use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
}; };
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat}; use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use std::iter; use std::iter;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -96,7 +95,7 @@ impl Validation {
&self, &self,
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> { ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some(sender) = &self.sender { if let Some(sender) = &self.sender {
// Create response channel // Create response channel
@ -122,7 +121,7 @@ impl Validation {
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel // Create response channel
@ -181,11 +180,7 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
} }
Ok(( Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
vec![Chunk::Text(inputs).into()],
input_length,
max_new_tokens,
))
} }
} }
@ -589,7 +584,7 @@ fn prepare_input(
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
config: Option<&Config>, config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>, preprocessor_config: Option<&HubPreprocessorConfig>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> { ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*; use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config {
@ -601,16 +596,16 @@ fn prepare_input(
let chunk_start = chunk.start(); let chunk_start = chunk.start();
let chunk_end = chunk.end(); let chunk_end = chunk.end();
if chunk_start != start { if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]); tokenizer_query.push_str(&inputs[start..chunk_start]);
} }
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width)); tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end; start = chunk_end;
} }
if start != inputs.len() { if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); input_chunks.push(Chunk::Text(inputs[start..].to_string()));
tokenizer_query.push_str(&inputs[start..]); tokenizer_query.push_str(&inputs[start..]);
} }
@ -618,7 +613,7 @@ fn prepare_input(
(tokenizer_query, input_chunks) (tokenizer_query, input_chunks)
} }
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), _ => (inputs.clone(), vec![Chunk::Text(inputs)]),
}; };
// Get the number of tokens in the input // Get the number of tokens in the input
@ -631,18 +626,51 @@ fn prepare_input(
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span, Span,
); );
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Image {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image),
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<Chunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c {
Chunk::Text(text) => output.push_str(text),
Chunk::Image(Image { data, mimetype }) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
});
output
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) enum ValidGrammar { pub enum ValidGrammar {
Json(String), Json(String),
Regex(String), Regex(String),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidParameters { pub struct ValidParameters {
/// / exponential scaling output probability distribution /// / exponential scaling output probability distribution
pub temperature: f32, pub temperature: f32,
/// / restricting to the k highest probability elements /// / restricting to the k highest probability elements
@ -666,7 +694,7 @@ pub(crate) struct ValidParameters {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters { pub struct ValidStoppingParameters {
/// / Maximum number of generated tokens /// / Maximum number of generated tokens
pub max_new_tokens: u32, pub max_new_tokens: u32,
/// / Optional stopping sequences /// / Optional stopping sequences
@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest { pub struct ValidGenerateRequest {
pub inputs: Vec<InputChunk>, pub inputs: Vec<Chunk>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub decoder_input_details: bool, pub decoder_input_details: bool,
@ -750,6 +778,8 @@ pub enum ValidationError {
InvalidImageContent(String), InvalidImageContent(String),
#[error("Could not fetch image: {0}")] #[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error), FailedFetchImage(#[from] reqwest::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
} }
#[cfg(test)] #[cfg(test)]

View File

@ -34,7 +34,6 @@ def reshape_and_cache(
def paged_attention( def paged_attention(
out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
@ -85,7 +84,7 @@ def paged_attention(
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None: if softcap is None:
softcap = 0.0 softcap = 0.0
out2 = flash_attn_2_cuda.varlen_fwd( out = flash_attn_2_cuda.varlen_fwd(
query, query,
key_cache, key_cache,
value_cache, value_cache,
@ -108,13 +107,15 @@ def paged_attention(
False, # return softmax False, # return softmax
None, # generator None, # generator
) )
return out2[0] return out[0]
else: else:
if softcap is not None: if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping") raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths input_lengths = seqlen.input_lengths
from vllm._C import ops from vllm._C import ops
out = torch.empty_like(query)
use_v1 = max_s <= 8192 and ( use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512 max_num_partitions == 1 or num_seqs * num_heads > 512
) )
@ -171,6 +172,10 @@ def paged_attention(
try: try:
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
import flash_attn_2_cuda import flash_attn_2_cuda
V2 = True V2 = True
@ -200,13 +205,13 @@ except ImportError:
SUPPORTS_WINDOWING = V2 SUPPORTS_WINDOWING = V2
if V2: if V2:
def attention( def attention(
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -214,6 +219,7 @@ if V2:
causal=True, causal=True,
softcap=0.0, softcap=0.0,
): ):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
@ -238,7 +244,7 @@ if V2:
softcap, softcap,
False, False,
None, None,
) )[0]
else: else:
@ -246,7 +252,6 @@ else:
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -286,7 +291,8 @@ else:
.reshape(original_shape[0], -1, original_shape[2]) .reshape(original_shape[0], -1, original_shape[2])
) )
return flash_attn_cuda.fwd( out = torch.empty_like(q)
flash_attn_cuda.fwd(
q, q,
k, k,
v, v,
@ -303,3 +309,4 @@ else:
0, 0,
None, None,
) )
return out

View File

@ -11,7 +11,6 @@ def attention(
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -19,6 +18,8 @@ def attention(
causal=True, causal=True,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return ipex.llm.functional.varlen_attention( return ipex.llm.functional.varlen_attention(
q, q,
@ -51,7 +52,6 @@ def reshape_and_cache(
def paged_attention( def paged_attention(
out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
@ -62,6 +62,7 @@ def paged_attention(
max_s: int, max_s: int,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out, out,
query, query,

View File

@ -39,7 +39,6 @@ def reshape_and_cache(
def paged_attention( def paged_attention(
out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
@ -72,6 +71,8 @@ def paged_attention(
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths input_lengths = input_lengths.input_lengths
out = torch.empty_like(query)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
@ -174,7 +175,6 @@ if ENGINE == "ck":
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -184,6 +184,8 @@ if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
q, q,
@ -209,13 +211,14 @@ elif ENGINE == "triton":
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=True, causal=True,
): ):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention( output, _ = triton_attention(
q, q,

View File

@ -124,50 +124,7 @@ class GPTQWeightsLoader(WeightsLoader):
self.sym = sym self.sym = sym
def get_weights(self, weights: Weights, prefix: str): def get_weights(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights) self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
use_exllama = True use_exllama = True
if self.bits != 4: if self.bits != 4:
@ -248,11 +205,6 @@ class GPTQWeightsLoader(WeightsLoader):
prefix: str, prefix: str,
block_sizes: Union[int, List[int]], block_sizes: Union[int, List[int]],
): ):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try: try:
qweight = weights.get_packed_sharded( qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@ -267,36 +219,6 @@ class GPTQWeightsLoader(WeightsLoader):
scales = scales.to(dtype=weights.dtype) scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights) self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = weights.get_packed_sharded( qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
@ -334,11 +256,6 @@ class GPTQWeightsLoader(WeightsLoader):
) )
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try: try:
qweight = torch.cat( qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@ -353,41 +270,6 @@ class GPTQWeightsLoader(WeightsLoader):
) )
self._get_gptq_params(weights) self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = torch.cat( qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@ -441,59 +323,7 @@ class GPTQWeightsLoader(WeightsLoader):
) )
def get_weights_row(self, weights: Weights, prefix: str): def get_weights_row(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights) self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True use_exllama = True
if self.bits != 4: if self.bits != 4:

View File

@ -17,7 +17,7 @@ from loguru import logger
from typing import Optional from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error from text_generation_server.layers.gptq.utils import torch_snr_error
from text_generation_server.utils.weights import DefaultWeightsLoader from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
DEV = torch.device("cuda:0") DEV = torch.device("cuda:0")
@ -897,7 +897,7 @@ def quantize(
dtype=torch.float16, dtype=torch.float16,
process_group=process_group, process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]}, aliases={"embed_tokens.weight": ["lm_head.weight"]},
weights_loader=DefaultWeightsLoader(), weights_loader=DefaultWeightsLoader(UnquantizedWeight),
) )
hooks = [] hooks = []
for name, module in model.named_modules(): for name, module in model.named_modules():
@ -960,9 +960,6 @@ def quantize(
state_dict = model.state_dict() state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor([bits])
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
state_dict["gptq_sym"] = torch.BoolTensor([sym])
max_shard_size = "10GB" max_shard_size = "10GB"
shards, index = shard_checkpoint( shards, index = shard_checkpoint(
@ -994,6 +991,15 @@ def quantize(
f"index located at {save_index_file}." f"index located at {save_index_file}."
) )
config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
config.quantization_config = {
"bits": bits,
"group_size": groupsize,
"damp_percent": percdamp,
"desc_act": act_order,
"static_groups": False,
"sym": sym,
"quant_method": "gptq",
}
config.save_pretrained(output_dir) config.save_pretrained(output_dir)
logger.info("Saved config") logger.info("Saved config")
logger.info("Saving tokenizer") logger.info("Saving tokenizer")

View File

@ -1,7 +1,6 @@
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import ( from text_generation_server.layers.marlin.gptq import (
GPTQMarlinLinear, GPTQMarlinWeightsLoader,
GPTQMarlinWeight,
can_use_gptq_marlin, can_use_gptq_marlin,
repack_gptq_for_marlin, repack_gptq_for_marlin,
) )
@ -9,8 +8,7 @@ from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
__all__ = [ __all__ = [
"GPTQMarlinFP8Linear", "GPTQMarlinFP8Linear",
"GPTQMarlinLinear", "GPTQMarlinWeightsLoader",
"GPTQMarlinWeight",
"MarlinWeightsLoader", "MarlinWeightsLoader",
"can_use_gptq_marlin", "can_use_gptq_marlin",
"repack_gptq_for_marlin", "repack_gptq_for_marlin",

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import List, Optional, Union
import numpy import numpy
import torch import torch
@ -13,7 +13,7 @@ from text_generation_server.layers.marlin.util import (
) )
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: try:
import marlin_kernels import marlin_kernels
@ -48,6 +48,204 @@ def can_use_gptq_marlin(
) )
class GPTQMarlinWeightsLoader(WeightsLoader):
"""
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
@dataclass @dataclass
class GPTQMarlinWeight(Weight): class GPTQMarlinWeight(Weight):
""" """

View File

@ -484,6 +484,9 @@ def get_model(
) )
sliding_window = config_dict.get("sliding_window", -1) sliding_window = config_dict.get("sliding_window", -1)
if max_input_tokens is not None and max_input_tokens <= sliding_window:
sliding_window = -1
if ( if (
(sliding_window is not None and sliding_window != -1) (sliding_window is not None and sliding_window != -1)
and not SUPPORTS_WINDOWING and not SUPPORTS_WINDOWING

View File

@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
key, key,
value, value,
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
key, key,
value, value,
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
key, key,
value, value,
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
qkv[:, 0], qkv[:, 0],
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
# Reshape key and value and cache # Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
slots, slots,
) )
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1), torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
@ -392,8 +382,13 @@ class FlashRWLayer(nn.Module):
prefix = f"{prefix}.h.{layer_id}" prefix = f"{prefix}.h.{layer_id}"
# NOTE: Falcon 180B uses the ln_attn prefix
ln_prefix = "input_layernorm"
if config.num_hidden_layers == 80:
ln_prefix = "ln_attn"
self.input_layernorm = FastLayerNorm.load( self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", prefix=f"{prefix}.{ln_prefix}",
weights=weights, weights=weights,
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
@ -483,7 +478,13 @@ class FlashRWLayer(nn.Module):
class FlashRWLayerNorm(nn.Module): class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix: str, weights): def __init__(self, config, prefix: str, weights):
super().__init__() super().__init__()
self.num_ln = config.num_ln_in_parallel_attn # Falcon2 includes the number of layer norms in the config
# in the case no number of layer norms is provided, we default to 1
self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
# Falcon 180B uses the ln_attn prefix and has 2 layer norms
if config.num_hidden_layers == 80:
self.num_ln = 2
if self.num_ln == 1: if self.num_ln == 1:
self.input_ln = FastLayerNorm.load( self.input_ln = FastLayerNorm.load(

Some files were not shown because too many files have changed in this diff Show More