mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
Merge branch 'main' into hot_fix_xpu
This commit is contained in:
commit
30e70f2ceb
0
.devcontainer/Dockerfile.trtllm
Normal file
0
.devcontainer/Dockerfile.trtllm
Normal file
0
.devcontainer/devcontainer.json
Normal file
0
.devcontainer/devcontainer.json
Normal file
@ -2,3 +2,5 @@ aml
|
|||||||
target
|
target
|
||||||
server/transformers
|
server/transformers
|
||||||
server/flash-attention
|
server/flash-attention
|
||||||
|
cmake-build-debug/
|
||||||
|
cmake-build-release/
|
||||||
|
4
.github/workflows/autodocs.yaml
vendored
4
.github/workflows/autodocs.yaml
vendored
@ -28,7 +28,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install router
|
- name: Install router
|
||||||
id: install-router
|
id: install-router
|
||||||
run: cargo install --path router/
|
run: cargo install --path backends/v3/
|
||||||
|
|
||||||
- uses: actions/setup-node@v4
|
- uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
@ -41,5 +41,5 @@ jobs:
|
|||||||
|
|
||||||
- name: Check that documentation is up-to-date
|
- name: Check that documentation is up-to-date
|
||||||
run: |
|
run: |
|
||||||
npm install -g swagger-cli
|
npm install -g @redocly/cli
|
||||||
python update_doc.py --check
|
python update_doc.py --check
|
||||||
|
1
.github/workflows/ci_build.yaml
vendored
1
.github/workflows/ci_build.yaml
vendored
@ -10,6 +10,7 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- ".github/workflows/build.yaml"
|
- ".github/workflows/build.yaml"
|
||||||
- "integration-tests/**"
|
- "integration-tests/**"
|
||||||
|
- "backends/**"
|
||||||
- "server/**"
|
- "server/**"
|
||||||
- "proto/**"
|
- "proto/**"
|
||||||
- "router/**"
|
- "router/**"
|
||||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -3,6 +3,10 @@ target
|
|||||||
router/tokenizer.json
|
router/tokenizer.json
|
||||||
*__pycache__*
|
*__pycache__*
|
||||||
|
|
||||||
|
backends/v3/src/client/pb
|
||||||
|
backends/client/src/v2/pb
|
||||||
|
backends/client/src/v3/pb
|
||||||
|
|
||||||
# ROCm auto-generated files
|
# ROCm auto-generated files
|
||||||
*.hip
|
*.hip
|
||||||
server/exllamav2_kernels/exllamav2_kernels/hip/
|
server/exllamav2_kernels/exllamav2_kernels/hip/
|
||||||
|
@ -13,8 +13,8 @@ repos:
|
|||||||
- repo: https://github.com/doublify/pre-commit-rust
|
- repo: https://github.com/doublify/pre-commit-rust
|
||||||
rev: v1.0
|
rev: v1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: fmt
|
|
||||||
- id: cargo-check
|
- id: cargo-check
|
||||||
|
- id: fmt
|
||||||
- id: clippy
|
- id: clippy
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.3.0
|
rev: v0.3.0
|
||||||
|
79
.redocly.lint-ignore.yaml
Normal file
79
.redocly.lint-ignore.yaml
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
|
||||||
|
# See https://redoc.ly/docs/cli/ for more information.
|
||||||
|
docs/openapi.json:
|
||||||
|
no-empty-servers:
|
||||||
|
- '#/openapi'
|
||||||
|
spec:
|
||||||
|
- >-
|
||||||
|
#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
|
||||||
|
- >-
|
||||||
|
#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
|
||||||
|
- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
|
||||||
|
- >-
|
||||||
|
#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
|
||||||
|
- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
|
||||||
|
- >-
|
||||||
|
#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
|
||||||
|
- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
|
||||||
|
- >-
|
||||||
|
#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
|
||||||
|
- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
|
||||||
|
- >-
|
||||||
|
#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
|
||||||
|
- '#/components/schemas/GenerateResponse/properties/details/nullable'
|
||||||
|
- '#/components/schemas/StreamResponse/properties/details/nullable'
|
||||||
|
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
|
||||||
|
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
|
||||||
|
- '#/components/schemas/ToolChoice/nullable'
|
||||||
|
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
|
||||||
|
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
|
||||||
|
no-invalid-media-type-examples:
|
||||||
|
- '#/paths/~1/post/responses/422/content/application~1json/example'
|
||||||
|
- '#/paths/~1/post/responses/424/content/application~1json/example'
|
||||||
|
- '#/paths/~1/post/responses/429/content/application~1json/example'
|
||||||
|
- '#/paths/~1/post/responses/500/content/application~1json/example'
|
||||||
|
- '#/paths/~1generate/post/responses/422/content/application~1json/example'
|
||||||
|
- '#/paths/~1generate/post/responses/424/content/application~1json/example'
|
||||||
|
- '#/paths/~1generate/post/responses/429/content/application~1json/example'
|
||||||
|
- '#/paths/~1generate/post/responses/500/content/application~1json/example'
|
||||||
|
- >-
|
||||||
|
#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
|
||||||
|
- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
|
||||||
|
- >-
|
||||||
|
#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
|
||||||
|
operation-4xx-response:
|
||||||
|
- '#/paths/~1health/get/responses'
|
||||||
|
- '#/paths/~1info/get/responses'
|
||||||
|
- '#/paths/~1metrics/get/responses'
|
||||||
|
no-unused-components:
|
||||||
|
- '#/components/schemas/Completion'
|
||||||
|
security-defined:
|
||||||
|
- '#/paths/~1/post'
|
||||||
|
- '#/paths/~1generate/post'
|
||||||
|
- '#/paths/~1generate_stream/post'
|
||||||
|
- '#/paths/~1health/get'
|
||||||
|
- '#/paths/~1info/get'
|
||||||
|
- '#/paths/~1metrics/get'
|
||||||
|
- '#/paths/~1tokenize/post'
|
||||||
|
- '#/paths/~1v1~1chat~1completions/post'
|
||||||
|
- '#/paths/~1v1~1completions/post'
|
734
Cargo.lock
generated
734
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
21
Cargo.toml
21
Cargo.toml
@ -1,10 +1,19 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"benchmark",
|
"benchmark",
|
||||||
"router",
|
"backends/v3",
|
||||||
"router/client",
|
"backends/grpc-metadata",
|
||||||
"router/grpc-metadata",
|
"backends/trtllm",
|
||||||
"launcher"
|
"backends/client",
|
||||||
|
"launcher"
|
||||||
|
]
|
||||||
|
default-members = [
|
||||||
|
"benchmark",
|
||||||
|
"backends/v3",
|
||||||
|
"backends/grpc-metadata",
|
||||||
|
# "backends/trtllm",
|
||||||
|
"backends/client",
|
||||||
|
"launcher"
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
|
|||||||
base64 = "0.22.0"
|
base64 = "0.22.0"
|
||||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||||
|
metrics = { version = "0.23.0" }
|
||||||
|
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
incremental = true
|
incremental = true
|
||||||
|
@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY benchmark benchmark
|
COPY benchmark benchmark
|
||||||
COPY router router
|
COPY router router
|
||||||
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo chef prepare --recipe-path recipe.json
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY benchmark benchmark
|
COPY benchmark benchmark
|
||||||
COPY router router
|
COPY router router
|
||||||
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo build --profile release-opt
|
RUN cargo build --profile release-opt
|
||||||
|
|
||||||
|
23
Dockerfile.trtllm
Normal file
23
Dockerfile.trtllm
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# All the tooling for CUDA
|
||||||
|
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||||
|
|
||||||
|
WORKDIR /usr/src/tgi/backends/trtllm
|
||||||
|
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
|
||||||
|
|
||||||
|
COPY . /usr/src/tgi
|
||||||
|
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
|
||||||
|
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
|
||||||
|
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
|
||||||
|
|
||||||
|
# All the tooling for Rust
|
||||||
|
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||||
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
# Include CUDA related libraries and tools to the Rust based image
|
||||||
|
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
|
||||||
|
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||||
|
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
|
||||||
|
ENV PATH=/usr/local/cuda/bin:$PATH
|
||||||
|
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3
|
@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY benchmark benchmark
|
COPY benchmark benchmark
|
||||||
COPY router router
|
COPY router router
|
||||||
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo chef prepare --recipe-path recipe.json
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY benchmark benchmark
|
COPY benchmark benchmark
|
||||||
COPY router router
|
COPY router router
|
||||||
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo build --profile release-opt
|
RUN cargo build --profile release-opt
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY benchmark benchmark
|
COPY benchmark benchmark
|
||||||
COPY router router
|
COPY router router
|
||||||
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo chef prepare --recipe-path recipe.json
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||||||
COPY proto proto
|
COPY proto proto
|
||||||
COPY benchmark benchmark
|
COPY benchmark benchmark
|
||||||
COPY router router
|
COPY router router
|
||||||
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo build --profile release-opt
|
RUN cargo build --profile release-opt
|
||||||
|
|
||||||
|
6
Makefile
6
Makefile
@ -5,13 +5,13 @@ install-server-cpu:
|
|||||||
cd server && make install-server
|
cd server && make install-server
|
||||||
|
|
||||||
install-router:
|
install-router:
|
||||||
cd router && cargo install --path .
|
cargo install --path backends/v3/
|
||||||
|
|
||||||
install-launcher:
|
install-launcher:
|
||||||
cd launcher && cargo install --path .
|
cargo install --path launcher/
|
||||||
|
|
||||||
install-benchmark:
|
install-benchmark:
|
||||||
cd benchmark && cargo install --path .
|
cargo install --path benchmark/
|
||||||
|
|
||||||
install: install-server install-router install-launcher
|
install: install-server install-router install-launcher
|
||||||
|
|
||||||
|
63
backends/trtllm/CMakeLists.txt
Normal file
63
backends/trtllm/CMakeLists.txt
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
|
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||||
|
set(CMAKE_CXX_STANDARD 20)
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
include(ExternalProject)
|
||||||
|
|
||||||
|
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
|
||||||
|
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
|
||||||
|
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
|
||||||
|
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
|
||||||
|
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
|
||||||
|
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||||
|
|
||||||
|
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
||||||
|
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
||||||
|
|
||||||
|
#### External dependencies ####
|
||||||
|
include(cmake/fmt.cmake)
|
||||||
|
include(cmake/json.cmake)
|
||||||
|
include(cmake/spdlog.cmake)
|
||||||
|
include(cmake/trtllm.cmake)
|
||||||
|
|
||||||
|
# Let's build TRTLLM as part of CMake
|
||||||
|
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
|
||||||
|
|
||||||
|
# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so
|
||||||
|
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
|
||||||
|
|
||||||
|
# TGI TRTLLM Backend definition
|
||||||
|
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
|
||||||
|
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||||
|
target_include_directories(tgi_trtllm_backend_impl PRIVATE
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||||
|
$<INSTALL_INTERFACE:include>
|
||||||
|
)
|
||||||
|
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
||||||
|
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
|
||||||
|
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
|
||||||
|
|
||||||
|
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
|
||||||
|
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
|
||||||
|
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
|
||||||
|
|
||||||
|
#### Unit Tests ####
|
||||||
|
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
|
||||||
|
message(STATUS "Building tests")
|
||||||
|
FetchContent_Declare(
|
||||||
|
Catch2
|
||||||
|
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
||||||
|
GIT_TAG v3.6.0
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(Catch2)
|
||||||
|
|
||||||
|
# add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
|
||||||
|
# target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
|
||||||
|
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
|
||||||
|
include(CTest)
|
||||||
|
include(Catch)
|
||||||
|
# catch_discover_tests(tgi_trtllm_backend_tests)
|
||||||
|
endif ()
|
26
backends/trtllm/Cargo.toml
Normal file
26
backends/trtllm/Cargo.toml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
[package]
|
||||||
|
name = "text-generation-backends-trtllm"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
async-trait = "0.1"
|
||||||
|
async-stream = "0.3"
|
||||||
|
cxx = "1.0"
|
||||||
|
text-generation-router = { path = "../../router" }
|
||||||
|
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
||||||
|
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
|
tokio-stream = "0.1.15"
|
||||||
|
clap = { version = "4.5", features = ["derive"] }
|
||||||
|
thiserror = "1.0.62"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-opentelemetry = "0.24"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||||
|
log = { version = "0.4", features = [] }
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
cmake = "0.1"
|
||||||
|
cxx-build = { version = "1.0", features = ["parallel"] }
|
||||||
|
pkg-config = "0.3"
|
100
backends/trtllm/Dockerfile
Normal file
100
backends/trtllm/Dockerfile
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
|
||||||
|
ARG OMPI_VERSION="4.1.6"
|
||||||
|
|
||||||
|
# Build dependencies resolver stage
|
||||||
|
FROM lukemathwalker/cargo-chef:latest AS chef
|
||||||
|
WORKDIR /usr/src/text-generation-inference
|
||||||
|
|
||||||
|
FROM chef AS planner
|
||||||
|
COPY . .
|
||||||
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
|
# CUDA dependent dependencies resolver stage
|
||||||
|
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
|
apt update && apt install -y \
|
||||||
|
build-essential \
|
||||||
|
cmake \
|
||||||
|
curl \
|
||||||
|
gcc \
|
||||||
|
g++ \
|
||||||
|
git \
|
||||||
|
git-lfs \
|
||||||
|
libssl-dev \
|
||||||
|
ninja-build \
|
||||||
|
pkg-config \
|
||||||
|
python3 \
|
||||||
|
python3-setuptools \
|
||||||
|
tar \
|
||||||
|
wget
|
||||||
|
|
||||||
|
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
|
||||||
|
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
|
||||||
|
|
||||||
|
# Install OpenMPI
|
||||||
|
FROM cuda-builder AS mpi-builder
|
||||||
|
ARG OMPI_VERSION
|
||||||
|
|
||||||
|
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
|
||||||
|
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
|
||||||
|
mkdir /usr/src/mpi && \
|
||||||
|
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
||||||
|
cd /usr/src/mpi && \
|
||||||
|
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
|
||||||
|
make -j all && \
|
||||||
|
make install && \
|
||||||
|
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
||||||
|
|
||||||
|
# Install TensorRT
|
||||||
|
FROM cuda-builder AS trt-builder
|
||||||
|
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
|
||||||
|
RUN chmod +x /opt/install_tensorrt.sh && \
|
||||||
|
/opt/install_tensorrt.sh
|
||||||
|
|
||||||
|
# Build Backend
|
||||||
|
FROM cuda-builder AS tgi-builder
|
||||||
|
WORKDIR /usr/src/text-generation-inference
|
||||||
|
|
||||||
|
# Install Rust
|
||||||
|
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
|
||||||
|
chmod -R a+w /root/.rustup && \
|
||||||
|
chmod -R a+w /root/.cargo
|
||||||
|
|
||||||
|
ENV PATH="/root/.cargo/bin:$PATH"
|
||||||
|
RUN cargo install cargo-chef
|
||||||
|
|
||||||
|
# Cache dependencies
|
||||||
|
COPY --from=planner /usr/src/text-generation-inference/recipe.json .
|
||||||
|
RUN cargo chef cook --release --recipe-path recipe.json
|
||||||
|
|
||||||
|
# Build actual TGI
|
||||||
|
ARG CUDA_ARCH_LIST
|
||||||
|
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
|
||||||
|
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
|
||||||
|
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||||
|
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||||
|
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
|
||||||
|
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
|
||||||
|
|
||||||
|
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||||
|
WORKDIR /usr/local/tgi/bin
|
||||||
|
|
||||||
|
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
||||||
|
|
||||||
|
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||||
|
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||||
|
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
|
||||||
|
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
|
||||||
|
|
||||||
|
FROM runtime
|
||||||
|
|
||||||
|
LABEL co.huggingface.vendor="Hugging Face Inc."
|
||||||
|
LABEL org.opencontainers.image.authors="hardware@hf.co"
|
||||||
|
|
||||||
|
ENTRYPOINT ["./text-generation-launcher"]
|
||||||
|
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
|
46
backends/trtllm/README.md
Normal file
46
backends/trtllm/README.md
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# Text Generation Inference - TensorRT-LLM Backend Implementation
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API
|
||||||
|
|
||||||
|
## Simplified Request Sequence
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
actor User
|
||||||
|
participant TextGenerationInference.HttpServer
|
||||||
|
participant TextGenerationInference.TensorRtLlmBackend
|
||||||
|
participant TextGenerationInference.TensorRtLlmWorkerThread
|
||||||
|
participant TensorRtLlm.Executor
|
||||||
|
participant Nvidia.Gpu
|
||||||
|
User ->> TextGenerationInference.HttpServer: POST /generate
|
||||||
|
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
|
||||||
|
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
|
||||||
|
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
|
||||||
|
activate Nvidia.Gpu
|
||||||
|
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
|
||||||
|
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
|
||||||
|
rect rgb(10, 92, 54)
|
||||||
|
loop every 100us
|
||||||
|
rect rgb(15, 81, 50)
|
||||||
|
alt Acquire lock to query executor
|
||||||
|
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
|
||||||
|
else There are new generated tokens
|
||||||
|
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
|
||||||
|
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
|
||||||
|
rect rgb(11, 110, 79)
|
||||||
|
alt Generated token is final
|
||||||
|
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
|
||||||
|
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
|
||||||
|
else Generated token is not final
|
||||||
|
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
deactivate Nvidia.Gpu
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
```
|
150
backends/trtllm/build.rs
Normal file
150
backends/trtllm/build.rs
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
use cxx_build::CFG;
|
||||||
|
use pkg_config;
|
||||||
|
use std::env;
|
||||||
|
use std::env::consts::ARCH;
|
||||||
|
use std::path::{absolute, PathBuf};
|
||||||
|
|
||||||
|
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
||||||
|
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
||||||
|
const CUDA_REQUIRED_VERSION: &str = "12.5";
|
||||||
|
const MPI_REQUIRED_VERSION: &str = "4.1";
|
||||||
|
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
||||||
|
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
||||||
|
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
|
||||||
|
|
||||||
|
// Dependencies
|
||||||
|
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
|
||||||
|
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
|
||||||
|
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
|
||||||
|
("dylib", "tensorrt_llm"),
|
||||||
|
("static", "tensorrt_llm_executor_static"),
|
||||||
|
("dylib", "tensorrt_llm_nvrtc_wrapper"),
|
||||||
|
("dylib", "nvinfer_plugin_tensorrt_llm"),
|
||||||
|
("dylib", "decoder_attention"),
|
||||||
|
];
|
||||||
|
|
||||||
|
macro_rules! probe {
|
||||||
|
($name: expr, $version: expr) => {
|
||||||
|
if let Err(_) = pkg_config::probe_library($name) {
|
||||||
|
pkg_config::probe_library(&format!("{}-{}", $name, $version))
|
||||||
|
.expect(&format!("Failed to locate {}", $name));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
|
||||||
|
// Build the backend implementation through CMake
|
||||||
|
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
||||||
|
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
||||||
|
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
|
||||||
|
|
||||||
|
let mut install_path = PathBuf::from(install_path);
|
||||||
|
if !install_path.is_absolute() {
|
||||||
|
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = cmake::Config::new(".")
|
||||||
|
.uses_cxx11()
|
||||||
|
.generator("Ninja")
|
||||||
|
.profile(match is_debug {
|
||||||
|
true => "Debug",
|
||||||
|
false => "Release",
|
||||||
|
})
|
||||||
|
.env("OPT_LEVEL", opt_level)
|
||||||
|
.define("CMAKE_INSTALL_PREFIX", &install_path)
|
||||||
|
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
|
||||||
|
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
|
||||||
|
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Additional transitive CMake dependencies
|
||||||
|
let deps_folder = out_dir.join("build").join("_deps");
|
||||||
|
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
|
||||||
|
let dep_name = match is_debug {
|
||||||
|
true => format!("{}d", dependency),
|
||||||
|
false => String::from(dependency),
|
||||||
|
};
|
||||||
|
let dep_path = deps_folder.join(format!("{}-build", dependency));
|
||||||
|
println!("cargo:rustc-link-search={}", dep_path.display());
|
||||||
|
println!("cargo:rustc-link-lib=static={}", dep_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit linkage information from the artifacts we just built
|
||||||
|
let install_lib_path = install_path.join("lib");
|
||||||
|
|
||||||
|
println!(
|
||||||
|
r"cargo:warning=Adding link search path: {}",
|
||||||
|
install_lib_path.display()
|
||||||
|
);
|
||||||
|
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
|
||||||
|
|
||||||
|
(PathBuf::from(install_path), deps_folder)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_ffi_layer(deps_folder: &PathBuf) {
|
||||||
|
CFG.include_prefix = "backends/trtllm";
|
||||||
|
cxx_build::bridge("src/lib.rs")
|
||||||
|
.static_flag(true)
|
||||||
|
.include(deps_folder.join("fmt-src").join("include"))
|
||||||
|
.include(deps_folder.join("spdlog-src").join("include"))
|
||||||
|
.include(deps_folder.join("json-src").join("include"))
|
||||||
|
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
|
||||||
|
.include("/usr/local/cuda/include")
|
||||||
|
.include("/usr/local/tensorrt/include")
|
||||||
|
.file("src/ffi.cpp")
|
||||||
|
.std("c++20")
|
||||||
|
.compile("tgi_trtllm_backend");
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
||||||
|
println!("cargo:rerun-if-changed=include/backend.h");
|
||||||
|
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||||
|
println!("cargo:rerun-if-changed=include/ffi.h");
|
||||||
|
println!("cargo:rerun-if-changed=src/ffi.cpp");
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// Misc variables
|
||||||
|
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||||
|
let build_profile = env::var("PROFILE").unwrap();
|
||||||
|
let (is_debug, opt_level) = match build_profile.as_ref() {
|
||||||
|
"debug" => (true, "0"),
|
||||||
|
_ => (false, "3"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build the backend
|
||||||
|
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
||||||
|
|
||||||
|
// Build the FFI layer calling the backend above
|
||||||
|
build_ffi_layer(&deps_folder);
|
||||||
|
|
||||||
|
// Emit linkage search path
|
||||||
|
probe!("ompi", MPI_REQUIRED_VERSION);
|
||||||
|
|
||||||
|
// Probe CUDA & co. with pkg-config
|
||||||
|
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
|
||||||
|
probe!(name, CUDA_REQUIRED_VERSION);
|
||||||
|
});
|
||||||
|
|
||||||
|
// NCCL is slightly trickier because it might not have a pkgconfig installed
|
||||||
|
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
|
||||||
|
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
|
||||||
|
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
|
||||||
|
println!("cargo:rustc-link-lib=dylib=nccl");
|
||||||
|
|
||||||
|
// TensorRT
|
||||||
|
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
|
||||||
|
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
|
||||||
|
println!("cargo:rustc-link-lib=dylib=nvinfer");
|
||||||
|
|
||||||
|
// TensorRT-LLM
|
||||||
|
TENSORRT_LLM_TRANSITIVE_DEPS
|
||||||
|
.iter()
|
||||||
|
.for_each(|(link_type, name)| {
|
||||||
|
println!("cargo:rustc-link-lib={}={}", link_type, name);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Backend
|
||||||
|
BACKEND_DEPS.iter().for_each(|name| {
|
||||||
|
println!("cargo:rustc-link-lib=static={}", name);
|
||||||
|
});
|
||||||
|
}
|
6
backends/trtllm/cmake/fmt.cmake
Normal file
6
backends/trtllm/cmake/fmt.cmake
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
FetchContent_Declare(
|
||||||
|
fmt
|
||||||
|
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||||
|
GIT_TAG 11.0.1
|
||||||
|
)
|
||||||
|
FetchContent_MakeAvailable(fmt)
|
5
backends/trtllm/cmake/json.cmake
Normal file
5
backends/trtllm/cmake/json.cmake
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
fetchcontent_declare(
|
||||||
|
json
|
||||||
|
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
||||||
|
)
|
||||||
|
fetchcontent_makeavailable(json)
|
17
backends/trtllm/cmake/spdlog.cmake
Normal file
17
backends/trtllm/cmake/spdlog.cmake
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
set(SPDLOG_USE_FMT ON)
|
||||||
|
set(SPDLOG_BUILD_SHARED OFF)
|
||||||
|
set(SPDLOG_FMT_EXTERNAL ON)
|
||||||
|
|
||||||
|
# Define the level at which SPDLOG_ compilation level is defined
|
||||||
|
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||||
|
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
|
||||||
|
else ()
|
||||||
|
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
fetchcontent_declare(
|
||||||
|
spdlog
|
||||||
|
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||||
|
GIT_TAG v1.14.1
|
||||||
|
)
|
||||||
|
fetchcontent_makeavailable(spdlog)
|
42
backends/trtllm/cmake/trtllm.cmake
Normal file
42
backends/trtllm/cmake/trtllm.cmake
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||||
|
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
|
||||||
|
|
||||||
|
set(USE_CXX11_ABI ON)
|
||||||
|
set(BUILD_PYT OFF)
|
||||||
|
set(BUILD_PYBIND OFF)
|
||||||
|
set(BUILD_MICRO_BENCHMARKS OFF)
|
||||||
|
set(BUILD_BENCHMARKS OFF)
|
||||||
|
set(BUILD_TESTS OFF)
|
||||||
|
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
|
||||||
|
|
||||||
|
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||||
|
|
||||||
|
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||||
|
set(FAST_BUILD ON)
|
||||||
|
set(NVTX_DISABLE OFF)
|
||||||
|
else ()
|
||||||
|
set(FAST_BUILD OFF)
|
||||||
|
set(FAST_MATH ON)
|
||||||
|
set(NVTX_DISABLE ON)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
fetchcontent_declare(
|
||||||
|
trtllm
|
||||||
|
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
||||||
|
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
|
||||||
|
GIT_SHALLOW FALSE
|
||||||
|
)
|
||||||
|
fetchcontent_makeavailable(trtllm)
|
||||||
|
|
||||||
|
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
|
||||||
|
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||||
|
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||||
|
|
||||||
|
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
|
||||||
|
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
|
||||||
|
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
|
||||||
|
CACHE INTERNAL "nvrtc wrapper library path")
|
||||||
|
|
||||||
|
# The same Executor Static library
|
||||||
|
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
|
||||||
|
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
|
0
backends/trtllm/cmake/utils/detect_cuda_arch.cu
Normal file
0
backends/trtllm/cmake/utils/detect_cuda_arch.cu
Normal file
121
backends/trtllm/include/backend.h
Normal file
121
backends/trtllm/include/backend.h
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
//
|
||||||
|
// Created by Morgan Funtowicz on 6/30/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef TGI_TRTLLM_BACKEND_H
|
||||||
|
#define TGI_TRTLLM_BACKEND_H
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <span>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <tensorrt_llm/runtime/common.h>
|
||||||
|
#include <tensorrt_llm/executor/executor.h>
|
||||||
|
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
namespace tle = tensorrt_llm::executor;
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends {
|
||||||
|
using RequestId = tle::IdType;
|
||||||
|
using TokenId = tle::TokenIdType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize all the components required by TRTLLM.
|
||||||
|
* It is required to call this function before attempting to load any engine
|
||||||
|
*/
|
||||||
|
void InitializeBackend();
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param config TensorRT-LLM configuration object
|
||||||
|
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the sampling configuration from the parameters provided by TGI
|
||||||
|
* @param topK
|
||||||
|
* @param topP
|
||||||
|
* @param temperature
|
||||||
|
* @param repetition_penalty
|
||||||
|
* @param frequency_penalty
|
||||||
|
* @param seed
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
tle::SamplingConfig GetSamplingConfig(
|
||||||
|
uint32_t topK,
|
||||||
|
float_t topP,
|
||||||
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
|
uint64_t seed
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class TensorRtLlmBackend {
|
||||||
|
private:
|
||||||
|
const json config;
|
||||||
|
tle::Executor executor;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit TensorRtLlmBackend(
|
||||||
|
const std::filesystem::path &engineFolder,
|
||||||
|
const std::filesystem::path &executorWorker
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Indicate if the backend is ready to accept incoming request
|
||||||
|
* @return true if ready, false otherwise
|
||||||
|
*/
|
||||||
|
[[nodiscard]] bool IsReady() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query the executor for the number of token available for pulling
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
[[nodiscard]] size_t NumResponsesReady() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Submit a new generation task to the executor
|
||||||
|
* @param tokens
|
||||||
|
* @param topK
|
||||||
|
* @param topP
|
||||||
|
* @param temperature
|
||||||
|
* @param repetition_penalty
|
||||||
|
* @param frequency_penalty
|
||||||
|
* @param seed
|
||||||
|
* @return Request id related to this generation for reference
|
||||||
|
*/
|
||||||
|
[[nodiscard]] RequestId Submit(
|
||||||
|
const std::vector<TokenId> &tokens,
|
||||||
|
int32_t topK,
|
||||||
|
float_t topP,
|
||||||
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
|
uint64_t seed
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param requestId The request id to poll the generation results
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
std::vector<tle::Response> Poll(RequestId requestId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stop the underlying executor
|
||||||
|
*/
|
||||||
|
void Shutdown();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //TGI_TRTLLM_BACKEND_H
|
75
backends/trtllm/include/ffi.h
Normal file
75
backends/trtllm/include/ffi.h
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
//
|
||||||
|
// Created by mfuntowicz on 7/11/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
||||||
|
#define TGI_TRTLLM_BACKEND_FFI_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include "backend.h"
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends {
|
||||||
|
class TensorRtLlmBackendImpl;
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "backends/trtllm/src/lib.rs.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends {
|
||||||
|
|
||||||
|
// struct GenerationContext;
|
||||||
|
|
||||||
|
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
||||||
|
public:
|
||||||
|
/***
|
||||||
|
*
|
||||||
|
* @param engineFolder
|
||||||
|
* @param executorWorker
|
||||||
|
*/
|
||||||
|
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
||||||
|
|
||||||
|
/***
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
bool IsReady() const;
|
||||||
|
|
||||||
|
/***
|
||||||
|
*
|
||||||
|
* @param tokens
|
||||||
|
* @param topK
|
||||||
|
* @param topP
|
||||||
|
* @param temperature
|
||||||
|
* @param repetition_penalty
|
||||||
|
* @param frequency_penalty
|
||||||
|
* @param seed
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||||
|
uint64_t
|
||||||
|
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
||||||
|
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||||
|
|
||||||
|
/***
|
||||||
|
*
|
||||||
|
* @param requestId
|
||||||
|
* @param ctx
|
||||||
|
* @param callback
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
size_t StreamTokens(
|
||||||
|
const RequestId requestId,
|
||||||
|
huggingface::tgi::backends::GenerationContext *ctx,
|
||||||
|
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||||
|
huggingface::tgi::backends::GenerationStep)> callback);
|
||||||
|
};
|
||||||
|
|
||||||
|
/***
|
||||||
|
*
|
||||||
|
* @param engineFolder
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //TGI_TRTLLM_BACKEND_FFI_H
|
59
backends/trtllm/include/hardware.h
Normal file
59
backends/trtllm/include/hardware.h
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
//
|
||||||
|
// Created by mfuntowicz on 7/23/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||||
|
#define TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <limits>
|
||||||
|
#include <fmt/base.h>
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
#include <nvml.h>
|
||||||
|
|
||||||
|
namespace huggingface::hardware::cuda {
|
||||||
|
|
||||||
|
#define AMPERE_SM_MAJOR 8
|
||||||
|
#define HOPPER_SM_MAJOR 8
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
||||||
|
*/
|
||||||
|
struct CudaComputeCapabilities {
|
||||||
|
int32_t major;
|
||||||
|
int32_t minor;
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
|
||||||
|
|
||||||
|
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
|
||||||
|
};
|
||||||
|
|
||||||
|
CudaComputeCapabilities GetCudaComputeCapabilities() {
|
||||||
|
// Get the compute capabilities of the current hardware
|
||||||
|
nvmlDevice_t device;
|
||||||
|
CudaComputeCapabilities capabilities{0, 0};
|
||||||
|
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
|
||||||
|
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
|
||||||
|
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
|
||||||
|
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return capabilities;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the number of GPU detected. If no GPU is detected, return size_t::max()
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
std::optional<size_t> GetNumDevices() {
|
||||||
|
uint32_t numGpus = 0;
|
||||||
|
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
|
||||||
|
return std::optional(numGpus);
|
||||||
|
} else {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
|
146
backends/trtllm/lib/backend.cpp
Normal file
146
backends/trtllm/lib/backend.cpp
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include <fmt/ranges.h>
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
#include <nvml.h>
|
||||||
|
|
||||||
|
#include "backend.h"
|
||||||
|
#include "hardware.h"
|
||||||
|
|
||||||
|
void huggingface::tgi::backends::InitializeBackend() {
|
||||||
|
SPDLOG_INFO("Initializing Backend...");
|
||||||
|
nvmlInit_v2();
|
||||||
|
initTrtLlmPlugins();
|
||||||
|
|
||||||
|
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
||||||
|
if (numGpus.has_value()) {
|
||||||
|
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
||||||
|
} else {
|
||||||
|
SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]]
|
||||||
|
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||||
|
tle::ExecutorConfig execConfig(1);
|
||||||
|
|
||||||
|
// Retrieve the compute capabilities to enable some options at runtime
|
||||||
|
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
||||||
|
|
||||||
|
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
|
||||||
|
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
|
||||||
|
SPDLOG_INFO("Detected single engine deployment, using leader mode");
|
||||||
|
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||||
|
tle::CommunicationType::kMPI,
|
||||||
|
tle::CommunicationMode::kLEADER,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt
|
||||||
|
));
|
||||||
|
} else { // Multiple engines -> using orchestrator mode (MPI involved)
|
||||||
|
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
|
||||||
|
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||||
|
tle::CommunicationType::kMPI,
|
||||||
|
tle::CommunicationMode::kORCHESTRATOR,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
tle::OrchestratorConfig(true, workerPath, nullptr, true)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define some configuration variables
|
||||||
|
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
|
||||||
|
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
|
||||||
|
return execConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
|
uint32_t topK,
|
||||||
|
float_t topP,
|
||||||
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
|
uint64_t seed) {
|
||||||
|
return tle::SamplingConfig(
|
||||||
|
1, // TGI only use a single beam
|
||||||
|
topK,
|
||||||
|
topP,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
seed,
|
||||||
|
temperature,
|
||||||
|
temperature,
|
||||||
|
std::nullopt,
|
||||||
|
repetition_penalty,
|
||||||
|
std::nullopt,
|
||||||
|
frequency_penalty
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||||
|
const std::filesystem::path &enginesFolder,
|
||||||
|
const std::filesystem::path &executorWorker
|
||||||
|
) :
|
||||||
|
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
||||||
|
executor(
|
||||||
|
enginesFolder,
|
||||||
|
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||||
|
GetExecutorConfig(config, executorWorker.string()
|
||||||
|
)) {
|
||||||
|
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
|
||||||
|
return executor.canEnqueueRequests();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||||
|
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
||||||
|
return executor.getNumResponsesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||||
|
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||||
|
const std::vector<tle::TokenIdType> &tokens,
|
||||||
|
const int32_t topK,
|
||||||
|
const float_t topP,
|
||||||
|
const float_t temperature,
|
||||||
|
const float_t repetition_penalty,
|
||||||
|
const float_t frequency_penalty,
|
||||||
|
const uint64_t seed
|
||||||
|
) {
|
||||||
|
#ifdef NDEBUG
|
||||||
|
SPDLOG_DEBUG(
|
||||||
|
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
||||||
|
tokens.size(),
|
||||||
|
executor.getLatestIterationStats().back().numActiveRequests
|
||||||
|
);
|
||||||
|
#else
|
||||||
|
SPDLOG_DEBUG(
|
||||||
|
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
||||||
|
fmt::join(tokens, ", "),
|
||||||
|
executor.getLatestIterationStats().front().numActiveRequests
|
||||||
|
);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
||||||
|
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
||||||
|
|
||||||
|
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
|
const auto output = tle::OutputConfig(true, false, false, true, false);
|
||||||
|
return executor.enqueueRequest(
|
||||||
|
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard("Generated tokens result must be used")]]
|
||||||
|
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
|
||||||
|
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
|
||||||
|
return executor.awaitResponses(requestId);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
|
||||||
|
SPDLOG_INFO("Shutting down executor");
|
||||||
|
executor.shutdown();
|
||||||
|
}
|
111
backends/trtllm/scripts/install_tensorrt.sh
Executable file
111
backends/trtllm/scripts/install_tensorrt.sh
Executable file
@ -0,0 +1,111 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
TRT_VER="10.2.0.19"
|
||||||
|
CUDA_VER="12.5"
|
||||||
|
CUDNN_VER="9.2.1.18-1"
|
||||||
|
NCCL_VER="2.22.3-1+cuda12.5"
|
||||||
|
CUBLAS_VER="12.5.3.2-1"
|
||||||
|
NVRTC_VER="12.5.82-1"
|
||||||
|
|
||||||
|
for i in "$@"; do
|
||||||
|
case $i in
|
||||||
|
--TRT_VER=?*) TRT_VER="${i#*=}";;
|
||||||
|
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
|
||||||
|
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
|
||||||
|
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
|
||||||
|
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
|
||||||
|
*) ;;
|
||||||
|
esac
|
||||||
|
shift
|
||||||
|
done
|
||||||
|
|
||||||
|
NVCC_VERSION_OUTPUT=$(nvcc --version)
|
||||||
|
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
|
||||||
|
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
install_ubuntu_requirements() {
|
||||||
|
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||||
|
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||||
|
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.0-1_all.deb
|
||||||
|
|
||||||
|
apt-get update
|
||||||
|
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
||||||
|
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
|
||||||
|
fi
|
||||||
|
if [[ $(apt list --installed | grep libnccl) ]]; then
|
||||||
|
apt-get remove --purge -y --allow-change-held-packages libnccl*
|
||||||
|
fi
|
||||||
|
if [[ $(apt list --installed | grep libcublas) ]]; then
|
||||||
|
apt-get remove --purge -y --allow-change-held-packages libcublas*
|
||||||
|
fi
|
||||||
|
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
|
||||||
|
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
|
||||||
|
fi
|
||||||
|
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||||
|
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
|
||||||
|
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
|
||||||
|
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
|
||||||
|
# NVRTC static library doesn't exist in NGC PyTorch container.
|
||||||
|
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||||
|
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
|
||||||
|
apt-get clean
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
}
|
||||||
|
|
||||||
|
install_centos_requirements() {
|
||||||
|
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||||
|
yum -y update
|
||||||
|
yum -y install epel-release
|
||||||
|
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
|
||||||
|
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
|
||||||
|
yum clean all
|
||||||
|
}
|
||||||
|
|
||||||
|
install_tensorrt() {
|
||||||
|
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||||
|
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||||
|
TRT_CUDA_VERSION="12.5"
|
||||||
|
|
||||||
|
if [ -z "$RELEASE_URL_TRT" ];then
|
||||||
|
ARCH=${TRT_TARGETARCH}
|
||||||
|
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
|
||||||
|
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||||
|
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||||
|
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||||
|
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||||
|
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||||
|
fi
|
||||||
|
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||||
|
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||||
|
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
||||||
|
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||||
|
rm -rf /tmp/TensorRT.tar
|
||||||
|
}
|
||||||
|
|
||||||
|
# Install base packages depending on the base OS
|
||||||
|
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
|
||||||
|
case "$ID" in
|
||||||
|
debian)
|
||||||
|
install_ubuntu_requirements
|
||||||
|
install_tensorrt
|
||||||
|
;;
|
||||||
|
ubuntu)
|
||||||
|
install_ubuntu_requirements
|
||||||
|
install_tensorrt
|
||||||
|
;;
|
||||||
|
centos)
|
||||||
|
install_centos_requirements
|
||||||
|
install_tensorrt
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unable to determine OS..."
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
329
backends/trtllm/src/backend.rs
Normal file
329
backends/trtllm/src/backend.rs
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
use std::future::Future;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::pin::{pin, Pin};
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::{Arc, OnceLock};
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use cxx::UniquePtr;
|
||||||
|
use log::{error, warn};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time::{sleep, Instant};
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tokio_stream::{Stream, StreamExt};
|
||||||
|
use tracing::{instrument, span, Level};
|
||||||
|
|
||||||
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
|
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||||
|
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
|
||||||
|
use text_generation_router::{FinishReason, Token};
|
||||||
|
|
||||||
|
use crate::errors::TensorRtLlmBackendError;
|
||||||
|
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||||
|
|
||||||
|
// Value used to poll the state of the generation stream
|
||||||
|
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||||
|
|
||||||
|
type InferResult<T> = Result<T, InferError>;
|
||||||
|
|
||||||
|
pub(crate) struct Generation {
|
||||||
|
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||||
|
done: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Holds the user provided input to be executed along with a channel allowing
|
||||||
|
/// to bubble up all the generated tokens for that tokens the to end stream.
|
||||||
|
pub struct GenerationContext {
|
||||||
|
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||||
|
tokenizer: Arc<Tokenizer>,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
done: Arc<AtomicBool>,
|
||||||
|
queued: Instant,
|
||||||
|
start: Option<Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for Generation {
|
||||||
|
type Item = usize;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
||||||
|
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
||||||
|
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
||||||
|
});
|
||||||
|
|
||||||
|
if !self.done.load(Ordering::Relaxed) {
|
||||||
|
let backend = pin!(self.executor.read());
|
||||||
|
let status = match backend.poll(ctx) {
|
||||||
|
Poll::Ready(executor_r) => {
|
||||||
|
let ready = executor_r.num_responses_ready();
|
||||||
|
if ready == 0 {
|
||||||
|
Poll::Pending
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Some(ready))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
};
|
||||||
|
|
||||||
|
let waker = ctx.waker().clone();
|
||||||
|
tokio::spawn(async {
|
||||||
|
sleep(Duration::from_micros(*interval)).await;
|
||||||
|
waker.wake();
|
||||||
|
});
|
||||||
|
|
||||||
|
status
|
||||||
|
} else {
|
||||||
|
Poll::Ready(None) // end of stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
(1, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||||
|
unsafe impl Sync for TensorRtLlmBackendImpl {}
|
||||||
|
|
||||||
|
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
|
||||||
|
pub struct TensorRtLlmBackend {
|
||||||
|
tokenizer: Arc<Tokenizer>,
|
||||||
|
|
||||||
|
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
|
||||||
|
// the number of available tokens (read only) in the Generation stream
|
||||||
|
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TensorRtLlmBackend {
|
||||||
|
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
engine_folder: P,
|
||||||
|
executor_worker_path: PP,
|
||||||
|
) -> Result<Self, TensorRtLlmBackendError> {
|
||||||
|
Ok(TensorRtLlmBackend {
|
||||||
|
tokenizer: Arc::new(tokenizer),
|
||||||
|
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
|
||||||
|
engine_folder.as_ref().to_str().unwrap(),
|
||||||
|
executor_worker_path.as_ref().to_str().unwrap(),
|
||||||
|
))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
|
||||||
|
if request.top_n_tokens > 1 {
|
||||||
|
return Err(InferError::ValidationError(
|
||||||
|
ValidationError::TopNTokensDisabled,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Is it really needed? How can it be validated before?
|
||||||
|
if request.parameters.grammar.is_some() {
|
||||||
|
return Err(InferError::ValidationError(ValidationError::Grammar));
|
||||||
|
}
|
||||||
|
|
||||||
|
match request.inputs.len() {
|
||||||
|
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
|
||||||
|
2.. => Err(InferError::GenerationError(
|
||||||
|
"TensorRT-LLM backend don't support multi-chunk".into(),
|
||||||
|
)),
|
||||||
|
1 => match request.inputs.first().expect("Single item-chunk") {
|
||||||
|
Chunk::Text(text) => Ok(text),
|
||||||
|
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||||
|
tokens: Vec<u32>,
|
||||||
|
top_k: u32,
|
||||||
|
top_p: f32,
|
||||||
|
temperature: f32,
|
||||||
|
repetition_penalty: f32,
|
||||||
|
frequency_penalty: f32,
|
||||||
|
seed: u64,
|
||||||
|
) {
|
||||||
|
let tokenizer = Arc::clone(&self.tokenizer);
|
||||||
|
let executor = Arc::clone(&self.backend);
|
||||||
|
|
||||||
|
// Let's push this in async context
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Define the generation state
|
||||||
|
let mut generation = Generation {
|
||||||
|
executor: executor.clone(),
|
||||||
|
done: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Define the context over the generation
|
||||||
|
// TODO(asap): Do we really need so many shared-ownership?
|
||||||
|
let ctx = Box::new(GenerationContext {
|
||||||
|
sender: sender.clone(),
|
||||||
|
tokenizer,
|
||||||
|
tokens: vec![],
|
||||||
|
done: Arc::clone(&generation.done),
|
||||||
|
start: None,
|
||||||
|
queued: Instant::now(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// We are leaking the context on-purpose to avoid the box being dropped while there are
|
||||||
|
// still computation ongoing
|
||||||
|
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
|
||||||
|
let ctx_ = Box::leak(ctx);
|
||||||
|
|
||||||
|
// Submit the request to the batcher
|
||||||
|
let request_id = span!(Level::DEBUG, "submit")
|
||||||
|
.in_scope(|| async {
|
||||||
|
let mut handle = executor.write().await;
|
||||||
|
let request_id = handle.pin_mut().submit(
|
||||||
|
&tokens,
|
||||||
|
top_k as i32,
|
||||||
|
top_p,
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
seed,
|
||||||
|
);
|
||||||
|
|
||||||
|
request_id
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
while let Some(_) = generation.next().await {
|
||||||
|
let mut executor_w = executor.write().await;
|
||||||
|
let executor = executor_w.pin_mut();
|
||||||
|
|
||||||
|
span!(Level::DEBUG, "decode")
|
||||||
|
.in_scope(|| async {
|
||||||
|
unsafe {
|
||||||
|
executor.stream_tokens(
|
||||||
|
request_id,
|
||||||
|
ctx_,
|
||||||
|
|ctx: *mut GenerationContext, step: GenerationStep| {
|
||||||
|
let inner_ctx = &mut *ctx;
|
||||||
|
|
||||||
|
// Update the timestamp at which the request started effectively
|
||||||
|
// Can be a bit off, would need to be before the callback, let's see
|
||||||
|
inner_ctx.start.get_or_insert(Instant::now());
|
||||||
|
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
|
||||||
|
|
||||||
|
// Ensure we are not running into errors
|
||||||
|
let parcel = if !step.has_error {
|
||||||
|
// Insert the latest generated token to the tracker
|
||||||
|
inner_ctx.tokens.push(step.token_id);
|
||||||
|
|
||||||
|
// Decode the token
|
||||||
|
let text = inner_ctx
|
||||||
|
.tokenizer
|
||||||
|
.decode(&[step.token_id], true)
|
||||||
|
.expect("Failed to decode token");
|
||||||
|
|
||||||
|
let special = inner_ctx
|
||||||
|
.tokenizer
|
||||||
|
.get_added_vocabulary()
|
||||||
|
.is_special_token(&text);
|
||||||
|
|
||||||
|
// Create the structure holding the token
|
||||||
|
let token = Token {
|
||||||
|
id: step.token_id,
|
||||||
|
text,
|
||||||
|
logprob: step.log_prob,
|
||||||
|
special,
|
||||||
|
};
|
||||||
|
|
||||||
|
if step.is_final {
|
||||||
|
let generated_text = inner_ctx
|
||||||
|
.tokenizer
|
||||||
|
.decode(&inner_ctx.tokens, true)
|
||||||
|
.expect("Failed to decode generated_tokens");
|
||||||
|
|
||||||
|
Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
generated_text: GeneratedText {
|
||||||
|
text: generated_text,
|
||||||
|
generated_tokens: inner_ctx.tokens.len() as u32,
|
||||||
|
finish_reason: FinishReason::EndOfSequenceToken,
|
||||||
|
seed: None,
|
||||||
|
},
|
||||||
|
start: inner_ctx.start.unwrap_or(Instant::now()),
|
||||||
|
queued: inner_ctx.queued,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(InferStreamResponse::Intermediate {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error!("Error caught while decoding: {}", &step.error_msg);
|
||||||
|
Err(InferError::GenerationError(step.error_msg))
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send the parcel to the client
|
||||||
|
inner_ctx
|
||||||
|
.sender
|
||||||
|
.send(parcel)
|
||||||
|
.expect("Failed to sent msg through the channel");
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Properly" free the shared context...
|
||||||
|
// TODO: clean that piece of sh** asap
|
||||||
|
unsafe {
|
||||||
|
let _ = Box::from_raw(ctx_);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Backend for TensorRtLlmBackend {
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
|
||||||
|
// Let's add a few more validation
|
||||||
|
let input = TensorRtLlmBackend::validate(&request)?;
|
||||||
|
|
||||||
|
// Channel to stream the generated token as they come from the worker thread back to the transport layer
|
||||||
|
let (sender, receiver) = unbounded_channel();
|
||||||
|
|
||||||
|
// Unpack parameters
|
||||||
|
let params = &request.parameters;
|
||||||
|
|
||||||
|
// Preprocess the inputs to send to TRTLLM backend
|
||||||
|
let encoding = self
|
||||||
|
.tokenizer
|
||||||
|
.encode(input.as_str(), true)
|
||||||
|
.map_err(|e| InferError::GenerationError(e.to_string()))?;
|
||||||
|
|
||||||
|
// Generate the response
|
||||||
|
self.generate(
|
||||||
|
sender,
|
||||||
|
Vec::from(encoding.get_ids()),
|
||||||
|
params.top_k,
|
||||||
|
params.top_p,
|
||||||
|
params.temperature,
|
||||||
|
params.repetition_penalty,
|
||||||
|
params.frequency_penalty,
|
||||||
|
params.seed,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(UnboundedReceiverStream::new(receiver))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, _current_health: bool) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
15
backends/trtllm/src/errors.rs
Normal file
15
backends/trtllm/src/errors.rs
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use text_generation_router::server;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum TensorRtLlmBackendError {
|
||||||
|
#[error("Tokenizer error: {0}")]
|
||||||
|
Tokenizer(String),
|
||||||
|
#[error("Argument validation error: {0}")]
|
||||||
|
ArgumentValidation(String),
|
||||||
|
#[error("WebServer error: {0}")]
|
||||||
|
WebServer(#[from] server::WebServerError),
|
||||||
|
#[error("Tokio runtime failed to start: {0}")]
|
||||||
|
Tokio(#[from] std::io::Error),
|
||||||
|
}
|
84
backends/trtllm/src/ffi.cpp
Normal file
84
backends/trtllm/src/ffi.cpp
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
//
|
||||||
|
// Created by mfuntowicz on 6/30/24.
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <exception>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <limits>
|
||||||
|
#include <iterator>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
#include "backends/trtllm/include/ffi.h"
|
||||||
|
|
||||||
|
|
||||||
|
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
|
||||||
|
const std::string_view &engineFolder,
|
||||||
|
const std::string_view &executorWorker
|
||||||
|
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
||||||
|
|
||||||
|
|
||||||
|
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
||||||
|
return TensorRtLlmBackend::IsReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||||
|
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty, uint64_t seed) {
|
||||||
|
|
||||||
|
// This will copy all the items from the initial slice
|
||||||
|
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
||||||
|
return TensorRtLlmBackend::Submit(
|
||||||
|
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
||||||
|
const uint64_t requestId,
|
||||||
|
huggingface::tgi::backends::GenerationContext *ctx,
|
||||||
|
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||||
|
huggingface::tgi::backends::GenerationStep)> callback) {
|
||||||
|
|
||||||
|
size_t numTokens = 0;
|
||||||
|
for (const auto &item: Poll(requestId)) {
|
||||||
|
GenerationStep step;
|
||||||
|
if (!item.hasError()) {
|
||||||
|
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
||||||
|
const auto decoded = item.getResult();
|
||||||
|
|
||||||
|
const auto token = decoded.outputTokenIds[0][0];
|
||||||
|
const auto isFinal = decoded.isFinal;
|
||||||
|
const auto logProb = decoded.logProbs.value()[0][0];
|
||||||
|
|
||||||
|
++numTokens;
|
||||||
|
|
||||||
|
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
||||||
|
step = huggingface::tgi::backends::GenerationStep{
|
||||||
|
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
|
||||||
|
};
|
||||||
|
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
||||||
|
} else {
|
||||||
|
// TODO : Return rest::Result with error
|
||||||
|
const auto what = item.getErrorMsg();
|
||||||
|
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
|
||||||
|
step = huggingface::tgi::backends::GenerationStep{
|
||||||
|
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
callback(std::move(ctx), std::move(step));
|
||||||
|
}
|
||||||
|
|
||||||
|
return numTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||||
|
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
||||||
|
// Unconditionally call this to initialize and discover TRTLLM plugins
|
||||||
|
InitializeBackend();
|
||||||
|
|
||||||
|
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
|
||||||
|
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
|
||||||
|
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
|
||||||
|
}
|
78
backends/trtllm/src/lib.rs
Normal file
78
backends/trtllm/src/lib.rs
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
||||||
|
|
||||||
|
mod backend;
|
||||||
|
pub mod errors;
|
||||||
|
|
||||||
|
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||||
|
mod ffi {
|
||||||
|
|
||||||
|
/// Struct used as shared type between rust and C++ to represent the result
|
||||||
|
/// of a single decoding iteration
|
||||||
|
pub struct GenerationStep {
|
||||||
|
token_id: u32,
|
||||||
|
log_prob: f32,
|
||||||
|
is_final: bool,
|
||||||
|
has_error: bool,
|
||||||
|
error_msg: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "Rust" {
|
||||||
|
type GenerationContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe extern "C++" {
|
||||||
|
include!("backends/trtllm/src/ffi.cpp");
|
||||||
|
|
||||||
|
/// Represent an instance of the underlying TensorRT-LLM backend
|
||||||
|
type TensorRtLlmBackendImpl;
|
||||||
|
|
||||||
|
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
|
||||||
|
/// * `executor_worker`: Path to the TRTLLM executor worker
|
||||||
|
///
|
||||||
|
/// returns: <unknown>
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
#[rust_name = "create_tensorrt_llm_backend"]
|
||||||
|
fn CreateTensorRtLlmBackend(
|
||||||
|
engine_folder: &str,
|
||||||
|
executor_worker: &str,
|
||||||
|
) -> UniquePtr<TensorRtLlmBackendImpl>;
|
||||||
|
|
||||||
|
// #[rust_name = "is_ready"]
|
||||||
|
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
|
||||||
|
|
||||||
|
#[rust_name = "num_responses_ready"]
|
||||||
|
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
||||||
|
|
||||||
|
#[rust_name = "submit"]
|
||||||
|
fn Submit(
|
||||||
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
|
tokens: &[u32],
|
||||||
|
top_k: i32,
|
||||||
|
top_p: f32,
|
||||||
|
temperature: f32,
|
||||||
|
repetition_penalty: f32,
|
||||||
|
frequency_penalty: f32,
|
||||||
|
seed: u64,
|
||||||
|
) -> u64;
|
||||||
|
|
||||||
|
#[rust_name = "stream_tokens"]
|
||||||
|
unsafe fn StreamTokens(
|
||||||
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
|
request_id: u64,
|
||||||
|
ctx: *mut GenerationContext,
|
||||||
|
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
||||||
|
) -> usize;
|
||||||
|
|
||||||
|
// #[rust_name = "shutdown"]
|
||||||
|
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
|
||||||
|
}
|
||||||
|
}
|
166
backends/trtllm/src/main.rs
Normal file
166
backends/trtllm/src/main.rs
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||||
|
|
||||||
|
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||||
|
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
||||||
|
use text_generation_router::server;
|
||||||
|
|
||||||
|
/// App Configuration
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[clap(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
#[clap(default_value = "128", long, env)]
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
max_best_of: usize,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
#[clap(default_value = "5", long, env)]
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
#[clap(default_value = "1024", long, env)]
|
||||||
|
max_input_tokens: usize,
|
||||||
|
#[clap(default_value = "2048", long, env)]
|
||||||
|
max_total_tokens: usize,
|
||||||
|
#[clap(default_value = "4096", long, env)]
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
|
hostname: String,
|
||||||
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
|
port: u16,
|
||||||
|
#[clap(long, env, required = true)]
|
||||||
|
tokenizer_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
revision: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
model_id: String,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
validation_workers: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
json_output: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||||
|
otlp_service_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
messages_api_enabled: bool,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
auth_token: Option<String>,
|
||||||
|
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
|
||||||
|
executor_worker: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
|
// Get args
|
||||||
|
let args = Args::parse();
|
||||||
|
// Pattern match configuration
|
||||||
|
let Args {
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
model_id,
|
||||||
|
validation_workers,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
otlp_service_name,
|
||||||
|
cors_allow_origin,
|
||||||
|
messages_api_enabled,
|
||||||
|
max_client_batch_size,
|
||||||
|
auth_token,
|
||||||
|
executor_worker,
|
||||||
|
} = args;
|
||||||
|
|
||||||
|
// Launch Tokio runtime
|
||||||
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
|
// Validate args
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||||
|
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if validation_workers == 0 {
|
||||||
|
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||||
|
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||||
|
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
|
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !executor_worker.exists() {
|
||||||
|
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
|
||||||
|
"`executor_work` specified path doesn't exists: {}",
|
||||||
|
executor_worker.display()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
let tokenizer = Tokenizer::from_pretrained(
|
||||||
|
tokenizer_name.clone(),
|
||||||
|
Some(FromPretrainedParameters {
|
||||||
|
revision: revision.clone().unwrap_or(String::from("main")),
|
||||||
|
user_agent: HashMap::new(),
|
||||||
|
auth_token,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
|
||||||
|
|
||||||
|
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
|
||||||
|
server::run(
|
||||||
|
backend,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
validation_workers,
|
||||||
|
None,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
cors_allow_origin,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
messages_api_enabled,
|
||||||
|
true,
|
||||||
|
max_client_batch_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
14
backends/trtllm/tests/infer_test.cpp
Normal file
14
backends/trtllm/tests/infer_test.cpp
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
//
|
||||||
|
// Created by mfuntowicz on 7/2/24.
|
||||||
|
//
|
||||||
|
#include <catch2/catch_all.hpp>
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
#include "../include/backend.h"
|
||||||
|
|
||||||
|
TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
|
||||||
|
const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
|
||||||
|
const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
|
||||||
|
|
||||||
|
spdlog::info("Loading config from: {}", absolute(engines).string());
|
||||||
|
huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
|
||||||
|
}
|
66
backends/v3/Cargo.toml
Normal file
66
backends/v3/Cargo.toml
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
[package]
|
||||||
|
name = "text-generation-router-v3"
|
||||||
|
description = "Text Generation Webserver"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "text-generation-router"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
async-trait = "0.1.74"
|
||||||
|
async-stream = "0.3.5"
|
||||||
|
axum = { version = "0.7", features = ["json"] }
|
||||||
|
axum-tracing-opentelemetry = "0.16"
|
||||||
|
text-generation-router = { path = "../../router" }
|
||||||
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
|
grpc-metadata = { path = "../grpc-metadata" }
|
||||||
|
futures = "0.3.28"
|
||||||
|
hf-hub = { workspace = true }
|
||||||
|
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||||
|
metrics = { workspace = true }
|
||||||
|
metrics-exporter-prometheus = { workspace = true }
|
||||||
|
nohash-hasher = "0.2.0"
|
||||||
|
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||||
|
opentelemetry-otlp = "0.13.0"
|
||||||
|
rand = "0.8.5"
|
||||||
|
reqwest = { version = "0.11.20", features = [] }
|
||||||
|
serde = "1.0.188"
|
||||||
|
serde_json = "1.0.107"
|
||||||
|
thiserror = "1.0.48"
|
||||||
|
tokenizers = { workspace = true}
|
||||||
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
|
tokio-stream = "0.1.14"
|
||||||
|
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||||
|
tracing = "0.1.37"
|
||||||
|
tracing-opentelemetry = "0.21.0"
|
||||||
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
|
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||||
|
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||||
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
|
minijinja = { version = "2.0.2" }
|
||||||
|
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||||
|
futures-util = "0.3.30"
|
||||||
|
regex = "1.10.3"
|
||||||
|
once_cell = "1.19.0"
|
||||||
|
image = "0.25.1"
|
||||||
|
base64 = { workspace = true }
|
||||||
|
prost = "^0.12"
|
||||||
|
tonic = "^0.10"
|
||||||
|
tower = "^0.4"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
tonic-build = "0.10.1"
|
||||||
|
prost-build = "0.12.1"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["ngrok"]
|
||||||
|
ngrok = ["text-generation-router/ngrok"]
|
||||||
|
google = ["text-generation-router/google"]
|
||||||
|
kserve = ["text-generation-router/kserve"]
|
19
backends/v3/build.rs
Normal file
19
backends/v3/build.rs
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
use std::fs;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
println!("cargo:rerun-if-changed=../../proto/");
|
||||||
|
|
||||||
|
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||||
|
let mut config = prost_build::Config::new();
|
||||||
|
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||||
|
|
||||||
|
tonic_build::configure()
|
||||||
|
.build_client(true)
|
||||||
|
.build_server(false)
|
||||||
|
.out_dir("src/client/pb")
|
||||||
|
.include_file("mod.rs")
|
||||||
|
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||||
|
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
508
backends/v3/src/backend.rs
Normal file
508
backends/v3/src/backend.rs
Normal file
@ -0,0 +1,508 @@
|
|||||||
|
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||||
|
/// Batching and inference logic
|
||||||
|
use crate::queue::{Entry, Queue};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use nohash_hasher::IntMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
|
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||||
|
use tokio::sync::mpsc::error::SendError;
|
||||||
|
use tokio::sync::{mpsc, Notify};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
|
pub struct BackendV3 {
|
||||||
|
/// Request queue
|
||||||
|
queue: Queue,
|
||||||
|
/// Notify batcher on queue appends
|
||||||
|
batching_task_notifier: Arc<Notify>,
|
||||||
|
/// Client clone, used for health checks to skip the queue
|
||||||
|
client: ShardedClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendV3 {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn new(
|
||||||
|
client: ShardedClient,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
requires_padding: bool,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
) -> Self {
|
||||||
|
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||||
|
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
let block_size = if flashdecoding { 256 } else { 16 };
|
||||||
|
|
||||||
|
let queue = Queue::new(
|
||||||
|
requires_padding,
|
||||||
|
block_size,
|
||||||
|
window_size,
|
||||||
|
speculate,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
);
|
||||||
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
// Spawn batching background task that contains all the inference logic
|
||||||
|
tokio::spawn(batching_task(
|
||||||
|
client.clone(),
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
queue.clone(),
|
||||||
|
batching_task_notifier.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
queue,
|
||||||
|
batching_task_notifier,
|
||||||
|
client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Backend for BackendV3 {
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
// MPSC channel to communicate with the background batching task
|
||||||
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// Append the request to the queue
|
||||||
|
self.queue.append(Entry {
|
||||||
|
request,
|
||||||
|
response_tx,
|
||||||
|
span: Span::current(),
|
||||||
|
temp_span: None,
|
||||||
|
queue_time: Instant::now(),
|
||||||
|
batch_time: None,
|
||||||
|
block_allocation: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Notify the background task that we have a new entry in the queue that needs
|
||||||
|
// to be batched
|
||||||
|
self.batching_task_notifier.notify_one();
|
||||||
|
|
||||||
|
// Return stream
|
||||||
|
Ok(UnboundedReceiverStream::new(response_rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, current_health: bool) -> bool {
|
||||||
|
if current_health {
|
||||||
|
// Generation is healthy, we only check that the shards can allocate on device
|
||||||
|
self.client.device_health().await
|
||||||
|
} else {
|
||||||
|
self.client.model_health().await
|
||||||
|
}
|
||||||
|
.is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Batching logic
|
||||||
|
/// Will be launched in a background Tokio task
|
||||||
|
///
|
||||||
|
/// Batches requests and sends them to the inference server
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) async fn batching_task(
|
||||||
|
mut client: ShardedClient,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
queue: Queue,
|
||||||
|
notifier: Arc<Notify>,
|
||||||
|
) {
|
||||||
|
// Infinite loop
|
||||||
|
loop {
|
||||||
|
// Wait for a notification from the Infer struct
|
||||||
|
notifier.notified().await;
|
||||||
|
|
||||||
|
// Get the next batch from the queue
|
||||||
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
|
// waiting in the queue
|
||||||
|
while let Some((mut entries, batch, span)) = queue
|
||||||
|
.next_batch(
|
||||||
|
None,
|
||||||
|
max_batch_size,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
let mut waiting_tokens = 1;
|
||||||
|
|
||||||
|
// We loop until we do not receive any cached batch from the inference server (== until
|
||||||
|
// all requests have met their stopping criteria)
|
||||||
|
while let Some(batch) = cached_batch {
|
||||||
|
// Get current batch info
|
||||||
|
let batch_size = batch.size;
|
||||||
|
let batch_max_tokens = batch.max_tokens;
|
||||||
|
let mut batches = vec![batch];
|
||||||
|
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||||
|
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||||
|
|
||||||
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||||
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||||
|
// to add a new batch even though its size might be small
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// Minimum batch size
|
||||||
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||||
|
};
|
||||||
|
|
||||||
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||||
|
|
||||||
|
// Try to get a new batch
|
||||||
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
|
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
// Tracking metrics
|
||||||
|
if min_size.is_some() {
|
||||||
|
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||||
|
.increment(1);
|
||||||
|
} else {
|
||||||
|
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||||
|
.increment(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to add the info that this entry is waiting
|
||||||
|
// because a new batch is being computed
|
||||||
|
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||||
|
// Add relationships
|
||||||
|
span.follows_from(&entry_waiting_span);
|
||||||
|
entry_waiting_span.follows_from(&span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
|
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
// Reset waiting counter
|
||||||
|
waiting_tokens = 1;
|
||||||
|
// Extend current batch with the new batch
|
||||||
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
|
entries.extend(new_entries);
|
||||||
|
batches.push(new_cached_batch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create span for this batch to add context to inference calls
|
||||||
|
let next_batch_size = entries.len();
|
||||||
|
let next_batch_span =
|
||||||
|
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to link the batch back to this entry
|
||||||
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||||
|
// Add relationships
|
||||||
|
next_batch_span.follows_from(&entry_batch_span);
|
||||||
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
});
|
||||||
|
|
||||||
|
cached_batch = decode(&mut client, batches, &mut entries)
|
||||||
|
.instrument(next_batch_span)
|
||||||
|
.await;
|
||||||
|
waiting_tokens += 1;
|
||||||
|
}
|
||||||
|
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||||
|
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn prefill(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batch: Batch,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let batch_id = batch.id;
|
||||||
|
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||||
|
|
||||||
|
match client.prefill(batch).await {
|
||||||
|
Ok((generations, next_batch, timings)) => {
|
||||||
|
let start_filtering_time = Instant::now();
|
||||||
|
// Send generated tokens and filter stopped entries
|
||||||
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
|
// Filter next batch and remove requests that were stopped
|
||||||
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||||
|
.record(timings.forward.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||||
|
.record(timings.decode.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||||
|
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||||
|
.record(start_time.elapsed().as_secs_f64());
|
||||||
|
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
let _ = client.clear_cache(Some(batch_id)).await;
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn decode(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||||
|
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||||
|
|
||||||
|
match client.decode(batches).await {
|
||||||
|
Ok((generations, next_batch, timings)) => {
|
||||||
|
let start_filtering_time = Instant::now();
|
||||||
|
// Send generated tokens and filter stopped entries
|
||||||
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
|
// Filter next batch and remove requests that were stopped
|
||||||
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
if let Some(concat_duration) = timings.concat {
|
||||||
|
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||||
|
.record(concat_duration.as_secs_f64());
|
||||||
|
}
|
||||||
|
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||||
|
.record(timings.forward.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||||
|
.record(timings.decode.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||||
|
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||||
|
.record(start_time.elapsed().as_secs_f64());
|
||||||
|
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
for id in batch_ids {
|
||||||
|
let _ = client.clear_cache(Some(id)).await;
|
||||||
|
}
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a `batch` and remove all requests not present in `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn filter_batch(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
next_batch: Option<CachedBatch>,
|
||||||
|
entries: &IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let mut batch = next_batch?;
|
||||||
|
|
||||||
|
// No need to filter
|
||||||
|
if batch.size as usize == entries.len() {
|
||||||
|
return Some(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = batch.id;
|
||||||
|
|
||||||
|
// Retain only requests that are still in entries
|
||||||
|
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||||
|
|
||||||
|
if batch.request_ids.is_empty() {
|
||||||
|
// All requests have been filtered out
|
||||||
|
// Next batch is now empty
|
||||||
|
// Clear it from the Python shards cache
|
||||||
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||||
|
client.clear_cache(Some(id)).await.unwrap();
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// Filter Python shard cache
|
||||||
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||||
|
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||||
|
/// and filter entries
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
generations.into_iter().for_each(|generation| {
|
||||||
|
let id = generation.request_id;
|
||||||
|
// Get entry
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
|
let entry = entries
|
||||||
|
.get(&id)
|
||||||
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||||
|
// Send generation responses back to the infer task
|
||||||
|
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||||
|
// request and we need to stop generating hence why we unwrap_or(true)
|
||||||
|
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||||
|
tracing::error!("Entry response channel error.");
|
||||||
|
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||||
|
err
|
||||||
|
}).unwrap_or(true);
|
||||||
|
if stopped {
|
||||||
|
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send responses through the `entry` response channel
|
||||||
|
fn send_responses(
|
||||||
|
generation: Generation,
|
||||||
|
entry: &Entry,
|
||||||
|
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||||
|
// Return directly if the channel is disconnected
|
||||||
|
if entry.response_tx.is_closed() {
|
||||||
|
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stopped = false;
|
||||||
|
|
||||||
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
|
let prefill_tokens = prefill_tokens
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(prefill_tokens.logprobs)
|
||||||
|
.zip(prefill_tokens.texts)
|
||||||
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create last Token
|
||||||
|
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||||
|
let n = tokens_.ids.len();
|
||||||
|
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||||
|
let mut iterator = tokens_
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(tokens_.logprobs)
|
||||||
|
.zip(tokens_.texts)
|
||||||
|
.zip(tokens_.is_special)
|
||||||
|
.enumerate()
|
||||||
|
.peekable();
|
||||||
|
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||||
|
let token = Token {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
logprob,
|
||||||
|
special,
|
||||||
|
};
|
||||||
|
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||||
|
top_tokens_
|
||||||
|
.ids
|
||||||
|
.iter()
|
||||||
|
.zip(top_tokens_.logprobs.iter())
|
||||||
|
.zip(top_tokens_.texts.iter())
|
||||||
|
.zip(top_tokens_.is_special.iter())
|
||||||
|
.map(|(((&id, &logprob), text), &special)| Token {
|
||||||
|
id,
|
||||||
|
text: text.to_string(),
|
||||||
|
logprob,
|
||||||
|
special,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
match (&generation.generated_text, iterator.peek()) {
|
||||||
|
(Some(generated_text), None) => {
|
||||||
|
// Generation has ended
|
||||||
|
stopped = true;
|
||||||
|
// Send message
|
||||||
|
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens,
|
||||||
|
generated_text: GeneratedText::from(generated_text.clone()),
|
||||||
|
queued: entry.queue_time,
|
||||||
|
start: entry.batch_time.unwrap(),
|
||||||
|
}))?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Send message
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(stopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send errors to Infer for all `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
entries.drain().for_each(|(_, entry)| {
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
|
let err = InferError::GenerationError(error.to_string());
|
||||||
|
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||||
|
tracing::error!("{err}");
|
||||||
|
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Err(err))
|
||||||
|
.unwrap_or(());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||||
|
fn from(value: crate::client::GeneratedText) -> Self {
|
||||||
|
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||||
|
let finish_reason = match v3_finish_reason {
|
||||||
|
crate::client::FinishReason::Length => FinishReason::Length,
|
||||||
|
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
|
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
text: value.text,
|
||||||
|
generated_tokens: value.generated_tokens,
|
||||||
|
finish_reason,
|
||||||
|
seed: value.seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
284
backends/v3/src/client/grpc_client.rs
Normal file
284
backends/v3/src/client/grpc_client.rs
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
/// Single shard Client
|
||||||
|
use crate::client::{pb, Chunk};
|
||||||
|
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||||
|
use base64::engine::general_purpose::STANDARD;
|
||||||
|
use base64::Engine;
|
||||||
|
use grpc_metadata::InjectTelemetryContext;
|
||||||
|
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||||
|
use pb::generate::v3::*;
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tonic::transport::{Channel, Uri};
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
/// Text Generation Inference gRPC client
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Client {
|
||||||
|
stub: TextGenerationServiceClient<Channel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
/// Returns a client connected to the given url
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let channel = Channel::builder(uri).connect().await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stub: TextGenerationServiceClient::new(channel),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||||
|
.unwrap()
|
||||||
|
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||||
|
tokio::net::UnixStream::connect(path.clone())
|
||||||
|
}))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stub: TextGenerationServiceClient::new(channel),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a list of uris or unix sockets of all shards
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||||
|
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||||
|
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||||
|
ClientError::Connection("Server does not support v3 interface".to_string())
|
||||||
|
})?;
|
||||||
|
let urls = response
|
||||||
|
.into_inner()
|
||||||
|
.urls
|
||||||
|
.into_iter()
|
||||||
|
// Remove unix socket prefix
|
||||||
|
.map(|url| match url.strip_prefix("unix://") {
|
||||||
|
None => url,
|
||||||
|
Some(stripped_url) => stripped_url.to_string(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(urls)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||||
|
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||||
|
let response = self.stub.info(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get model health
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||||
|
let response = self.stub.health(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||||
|
self.stub.clear_cache(request).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let request = tonic::Request::new(FilterBatchRequest {
|
||||||
|
batch_id,
|
||||||
|
request_ids,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
|
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||||
|
Ok(filtered_batch.batch)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let mut n_tokens = 0;
|
||||||
|
let mut requests = Vec::new();
|
||||||
|
// Create requests
|
||||||
|
while n_tokens < max_prefill_tokens {
|
||||||
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
|
let mut input_chunks = Vec::new();
|
||||||
|
input_chunks
|
||||||
|
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||||
|
if n_tokens == 0 {
|
||||||
|
input_chunks.push(
|
||||||
|
Chunk::Image(Image {
|
||||||
|
// Safe unwrap, because we control the data.
|
||||||
|
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||||
|
mimetype: "image/jpeg;base64".to_string(),
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||||
|
// been updated to support chunks.
|
||||||
|
|
||||||
|
let mut inputs = String::new();
|
||||||
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
|
if n_tokens == 0 {
|
||||||
|
// 1 request is enough to test vision heads.
|
||||||
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
inputs.push_str(&format!(
|
||||||
|
"",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
requests.push(Request {
|
||||||
|
id: 0,
|
||||||
|
inputs,
|
||||||
|
input_chunks: Some(Input {
|
||||||
|
chunks: input_chunks,
|
||||||
|
}),
|
||||||
|
// We truncate the input on the server side to be sure that it has the correct size
|
||||||
|
truncate,
|
||||||
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
|
blocks: vec![],
|
||||||
|
slots: vec![],
|
||||||
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 0.9,
|
||||||
|
top_k: 10,
|
||||||
|
top_p: 0.9,
|
||||||
|
typical_p: 0.9,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.2,
|
||||||
|
frequency_penalty: 0.1,
|
||||||
|
watermark: true,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: max_total_tokens - truncate,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: true,
|
||||||
|
}),
|
||||||
|
prefill_logprobs: true,
|
||||||
|
top_n_tokens: 20,
|
||||||
|
adapter_id: None,
|
||||||
|
});
|
||||||
|
n_tokens += max_input_length;
|
||||||
|
|
||||||
|
// Check max_batch_size
|
||||||
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let batch = Batch {
|
||||||
|
id: 0,
|
||||||
|
size: requests.len() as u32,
|
||||||
|
requests,
|
||||||
|
max_tokens: max_input_length,
|
||||||
|
max_blocks: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
|
Ok(response.max_supported_total_tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||||
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
|
Ok((
|
||||||
|
response.generations,
|
||||||
|
response.batch,
|
||||||
|
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||||
|
let response = self.stub.decode(request).await?.into_inner();
|
||||||
|
Ok((
|
||||||
|
response.generations,
|
||||||
|
response.batch,
|
||||||
|
DecodeTimings::new(
|
||||||
|
response.concat_ns,
|
||||||
|
response.forward_ns,
|
||||||
|
response.decode_ns,
|
||||||
|
response.total_ns,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PrefillTimings {
|
||||||
|
pub forward: Duration,
|
||||||
|
pub decode: Duration,
|
||||||
|
pub total: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefillTimings {
|
||||||
|
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
forward: Duration::from_nanos(forward_ns),
|
||||||
|
decode: Duration::from_nanos(decode_ns),
|
||||||
|
total: Duration::from_nanos(total_ns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DecodeTimings {
|
||||||
|
pub concat: Option<Duration>,
|
||||||
|
pub forward: Duration,
|
||||||
|
pub decode: Duration,
|
||||||
|
pub total: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecodeTimings {
|
||||||
|
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
concat: concat_ns.map(Duration::from_nanos),
|
||||||
|
forward: Duration::from_nanos(forward_ns),
|
||||||
|
decode: Duration::from_nanos(decode_ns),
|
||||||
|
total: Duration::from_nanos(total_ns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
76
backends/v3/src/client/mod.rs
Normal file
76
backends/v3/src/client/mod.rs
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
//! Text Generation gRPC client library
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tonic::transport;
|
||||||
|
use tonic::Status;
|
||||||
|
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
mod pb;
|
||||||
|
|
||||||
|
mod grpc_client;
|
||||||
|
mod sharded_client;
|
||||||
|
|
||||||
|
pub use grpc_client::Client;
|
||||||
|
pub use pb::generate::v3::{
|
||||||
|
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||||
|
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||||
|
StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
pub use sharded_client::ShardedClient;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Health {
|
||||||
|
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||||
|
async fn device_health(&self) -> Result<()>;
|
||||||
|
|
||||||
|
/// Check if a generate server is healthy by doing a forward pass.
|
||||||
|
/// EXPENSIVE
|
||||||
|
async fn model_health(&self) -> Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ShardInfo {
|
||||||
|
pub requires_padding: bool,
|
||||||
|
pub dtype: String,
|
||||||
|
pub device_type: String,
|
||||||
|
pub window_size: Option<u32>,
|
||||||
|
pub speculate: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug, Clone)]
|
||||||
|
pub enum ClientError {
|
||||||
|
#[error("Could not connect to Text Generation server: {0}")]
|
||||||
|
Connection(String),
|
||||||
|
#[error("Server error: {0}")]
|
||||||
|
Generation(String),
|
||||||
|
#[error("Sharded results are empty")]
|
||||||
|
EmptyResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Status> for ClientError {
|
||||||
|
fn from(err: Status) -> Self {
|
||||||
|
let err = Self::Generation(err.message().to_string());
|
||||||
|
tracing::error!("{err}");
|
||||||
|
err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<transport::Error> for ClientError {
|
||||||
|
fn from(err: transport::Error) -> Self {
|
||||||
|
let err = Self::Connection(err.to_string());
|
||||||
|
tracing::error!("{err}");
|
||||||
|
err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Small convenience re-wrapping of `Chunk`.
|
||||||
|
impl From<Chunk> for InputChunk {
|
||||||
|
fn from(chunk: Chunk) -> Self {
|
||||||
|
InputChunk { chunk: Some(chunk) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, ClientError>;
|
260
backends/v3/src/client/sharded_client.rs
Normal file
260
backends/v3/src/client/sharded_client.rs
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
use crate::client::{ClientError, Result};
|
||||||
|
/// Multi shard Client
|
||||||
|
use crate::client::{Health, ShardInfo};
|
||||||
|
|
||||||
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
|
use crate::client::{
|
||||||
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
use crate::client::{Chunk, InfoResponse, Input};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use tonic::transport::Uri;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Text Generation Inference gRPC multi client
|
||||||
|
pub struct ShardedClient {
|
||||||
|
clients: Vec<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShardedClient {
|
||||||
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
|
Self { clients }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||||
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
|
// Get all uris/unix sockets from the master client
|
||||||
|
let uris = master_client.service_discovery().await?;
|
||||||
|
let futures = uris.into_iter().map(Client::connect_uds);
|
||||||
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||||
|
Ok(Self::new(clients?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given uri
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let master_client = Client::connect(uri).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let master_client = Client::connect_uds(path).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.info())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GRPC health check
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.health())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.clear_cache(batch_id))
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.into_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||||
|
.collect();
|
||||||
|
// all shards return the same message
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| {
|
||||||
|
Box::pin(client.warmup(
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
// Take the minimum value
|
||||||
|
let results = join_all(futures)
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||||
|
Ok(results.into_iter().flatten().min())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InfoResponse> for ShardInfo {
|
||||||
|
fn from(value: InfoResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
requires_padding: value.requires_padding,
|
||||||
|
dtype: value.dtype,
|
||||||
|
device_type: value.device_type,
|
||||||
|
window_size: value.window_size,
|
||||||
|
speculate: value.speculate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Health for ShardedClient {
|
||||||
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
self.clone().health().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_health(&self) -> Result<()> {
|
||||||
|
// Dummy batch of 1 token and 1 generated token
|
||||||
|
let liveness_request = Request {
|
||||||
|
id: u64::MAX,
|
||||||
|
inputs: "liveness".to_string(),
|
||||||
|
input_chunks: Some(Input {
|
||||||
|
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||||
|
}),
|
||||||
|
truncate: 10,
|
||||||
|
prefill_logprobs: false,
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
typical_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: 1,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: false,
|
||||||
|
}),
|
||||||
|
top_n_tokens: 0,
|
||||||
|
// Block 0 is reserved for health checks
|
||||||
|
blocks: vec![0],
|
||||||
|
slots: (0..16).collect(),
|
||||||
|
adapter_id: None,
|
||||||
|
};
|
||||||
|
let batch = Batch {
|
||||||
|
id: u64::MAX,
|
||||||
|
requests: vec![liveness_request],
|
||||||
|
size: 1,
|
||||||
|
max_tokens: 2,
|
||||||
|
max_blocks: 1,
|
||||||
|
};
|
||||||
|
self.clone().prefill(batch).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
142
backends/v3/src/lib.rs
Normal file
142
backends/v3/src/lib.rs
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
mod backend;
|
||||||
|
mod block_allocator;
|
||||||
|
mod client;
|
||||||
|
mod queue;
|
||||||
|
|
||||||
|
use crate::client::{ClientError, ShardedClient};
|
||||||
|
pub(crate) use backend::BackendV3;
|
||||||
|
use serde::Serialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
|
pub struct BackendInfo {
|
||||||
|
/// Mandatory
|
||||||
|
#[schema(example = "cuda")]
|
||||||
|
pub model_device_type: String,
|
||||||
|
#[schema(example = "torch.float16")]
|
||||||
|
pub model_dtype: String,
|
||||||
|
|
||||||
|
/// Backend parameters
|
||||||
|
#[schema(example = "1")]
|
||||||
|
pub speculate: usize,
|
||||||
|
#[schema(example = "1.2")]
|
||||||
|
pub waiting_served_ratio: f32,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_batch_total_tokens: u32,
|
||||||
|
#[schema(example = "20")]
|
||||||
|
pub max_waiting_tokens: usize,
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub max_batch_size: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn connect_backend(
|
||||||
|
max_input_tokens: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||||
|
// Helper function
|
||||||
|
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||||
|
match max_supported_batch_total_tokens {
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
let max_batch_total_tokens = max_batch_total_tokens
|
||||||
|
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
Ok(max_batch_total_tokens)
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
|
tracing::warn!(
|
||||||
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
|
);
|
||||||
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(max_supported_batch_total_tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.map_err(V3Error::Connection)?;
|
||||||
|
|
||||||
|
// server is running on v3
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.map_err(V3Error::Cache)?;
|
||||||
|
// Get info from the shard
|
||||||
|
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
||||||
|
|
||||||
|
// Warmup model
|
||||||
|
tracing::info!("Warming up model");
|
||||||
|
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||||
|
sharded_client
|
||||||
|
.warmup(
|
||||||
|
max_input_tokens as u32,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_total_tokens as u32,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(V3Error::Warmup)?,
|
||||||
|
)?;
|
||||||
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
|
|
||||||
|
let backend_info = BackendInfo {
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
model_device_type: shard_info.device_type.clone(),
|
||||||
|
model_dtype: shard_info.dtype.clone(),
|
||||||
|
speculate: shard_info.speculate as usize,
|
||||||
|
};
|
||||||
|
|
||||||
|
let backend = BackendV3::new(
|
||||||
|
sharded_client,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
shard_info.requires_padding,
|
||||||
|
shard_info.window_size,
|
||||||
|
shard_info.speculate,
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!("Using backend V3");
|
||||||
|
|
||||||
|
Ok((backend, backend_info))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum V3Error {
|
||||||
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
Cache(ClientError),
|
||||||
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
Connection(ClientError),
|
||||||
|
#[error("Unable to get the Python model shards info: {0}")]
|
||||||
|
Info(ClientError),
|
||||||
|
#[error("Unable to warmup the Python model shards: {0}")]
|
||||||
|
Warmup(ClientError),
|
||||||
|
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||||
|
NotEnoughMemory(usize),
|
||||||
|
}
|
204
backends/v3/src/main.rs
Normal file
204
backends/v3/src/main.rs
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use text_generation_router::{server, usage_stats};
|
||||||
|
use text_generation_router_v3::{connect_backend, V3Error};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// App Configuration
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[clap(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<Commands>,
|
||||||
|
|
||||||
|
#[clap(default_value = "128", long, env)]
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
max_best_of: usize,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
#[clap(default_value = "5", long, env)]
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
#[clap(default_value = "1024", long, env)]
|
||||||
|
max_input_tokens: usize,
|
||||||
|
#[clap(default_value = "2048", long, env)]
|
||||||
|
max_total_tokens: usize,
|
||||||
|
#[clap(default_value = "1.2", long, env)]
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
#[clap(default_value = "4096", long, env)]
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
#[clap(default_value = "20", long, env)]
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
|
hostname: String,
|
||||||
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
|
port: u16,
|
||||||
|
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
|
tokenizer_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
revision: Option<String>,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
validation_workers: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
json_output: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||||
|
otlp_service_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_edge: Option<String>,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
messages_api_enabled: bool,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
#[clap(default_value = "on", long, env)]
|
||||||
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
PrintSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), RouterError> {
|
||||||
|
// Get args
|
||||||
|
let args = Args::parse();
|
||||||
|
// Pattern match configuration
|
||||||
|
let Args {
|
||||||
|
command,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
master_shard_uds_path,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
api_key,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
otlp_service_name,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
usage_stats,
|
||||||
|
} = args;
|
||||||
|
|
||||||
|
if let Some(Commands::PrintSchema) = command {
|
||||||
|
use utoipa::OpenApi;
|
||||||
|
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||||
|
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||||
|
println!("{}", api_doc);
|
||||||
|
std::process::exit(0);
|
||||||
|
};
|
||||||
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
|
// Validate args
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if validation_workers == 0 {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||||
|
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (backend, _backend_info) = connect_backend(
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
master_shard_uds_path,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
server::run(
|
||||||
|
backend,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
validation_workers,
|
||||||
|
api_key,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
usage_stats,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
enum RouterError {
|
||||||
|
#[error("Argument validation error: {0}")]
|
||||||
|
ArgumentValidation(String),
|
||||||
|
#[error("Backend failed: {0}")]
|
||||||
|
Backend(#[from] V3Error),
|
||||||
|
#[error("WebServer error: {0}")]
|
||||||
|
WebServer(#[from] server::WebServerError),
|
||||||
|
#[error("Tokio runtime failed to start: {0}")]
|
||||||
|
Tokio(#[from] std::io::Error),
|
||||||
|
}
|
@ -1,17 +1,17 @@
|
|||||||
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||||
use crate::infer::InferError;
|
use crate::client;
|
||||||
use crate::infer::InferStreamResponse;
|
use crate::client::{
|
||||||
use crate::validation::{
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
|
||||||
};
|
};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::{max, min};
|
use std::cmp::{max, min};
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_client::v3::{
|
use text_generation_router::infer::InferError;
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
|
use text_generation_router::validation::{
|
||||||
|
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||||
|
ValidStoppingParameters,
|
||||||
};
|
};
|
||||||
use text_generation_client::ChunksToString;
|
|
||||||
use text_generation_client::Input;
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
@ -337,8 +337,22 @@ impl State {
|
|||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
id,
|
id,
|
||||||
prefill_logprobs: entry.request.decoder_input_details,
|
prefill_logprobs: entry.request.decoder_input_details,
|
||||||
input_chunks: Some(Input {
|
input_chunks: Some(client::Input {
|
||||||
chunks: entry.request.inputs.clone(),
|
chunks: entry
|
||||||
|
.request
|
||||||
|
.inputs
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.map(|c| client::InputChunk {
|
||||||
|
chunk: Some(match c {
|
||||||
|
Chunk::Text(text) => client::Chunk::Text(text),
|
||||||
|
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
||||||
|
data: image.data,
|
||||||
|
mimetype: image.mimetype,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
}),
|
}),
|
||||||
inputs: entry.request.inputs.chunks_to_string(),
|
inputs: entry.request.inputs.chunks_to_string(),
|
||||||
truncate: entry.request.truncate,
|
truncate: entry.request.truncate,
|
@ -21,7 +21,7 @@ float-ord = "0.3.2"
|
|||||||
serde = {version = "1.0.188", features = ["derive"]}
|
serde = {version = "1.0.188", features = ["derive"]}
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
tabled = "0.14.0"
|
tabled = "0.14.0"
|
||||||
text-generation-client = { path = "../router/client" }
|
text-generation-client = { path = "../backends/client" }
|
||||||
thiserror = "1.0.48"
|
thiserror = "1.0.48"
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||||
|
@ -1580,16 +1580,11 @@
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"model_id",
|
"model_id",
|
||||||
"model_dtype",
|
|
||||||
"model_device_type",
|
|
||||||
"max_concurrent_requests",
|
"max_concurrent_requests",
|
||||||
"max_best_of",
|
"max_best_of",
|
||||||
"max_stop_sequences",
|
"max_stop_sequences",
|
||||||
"max_input_tokens",
|
"max_input_tokens",
|
||||||
"max_total_tokens",
|
"max_total_tokens",
|
||||||
"waiting_served_ratio",
|
|
||||||
"max_batch_total_tokens",
|
|
||||||
"max_waiting_tokens",
|
|
||||||
"validation_workers",
|
"validation_workers",
|
||||||
"max_client_batch_size",
|
"max_client_batch_size",
|
||||||
"router",
|
"router",
|
||||||
@ -1601,18 +1596,6 @@
|
|||||||
"example": "null",
|
"example": "null",
|
||||||
"nullable": true
|
"nullable": true
|
||||||
},
|
},
|
||||||
"max_batch_size": {
|
|
||||||
"type": "integer",
|
|
||||||
"example": "null",
|
|
||||||
"nullable": true,
|
|
||||||
"minimum": 0
|
|
||||||
},
|
|
||||||
"max_batch_total_tokens": {
|
|
||||||
"type": "integer",
|
|
||||||
"format": "int32",
|
|
||||||
"example": "32000",
|
|
||||||
"minimum": 0
|
|
||||||
},
|
|
||||||
"max_best_of": {
|
"max_best_of": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"example": "2",
|
"example": "2",
|
||||||
@ -1644,19 +1627,6 @@
|
|||||||
"example": "2048",
|
"example": "2048",
|
||||||
"minimum": 0
|
"minimum": 0
|
||||||
},
|
},
|
||||||
"max_waiting_tokens": {
|
|
||||||
"type": "integer",
|
|
||||||
"example": "20",
|
|
||||||
"minimum": 0
|
|
||||||
},
|
|
||||||
"model_device_type": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "cuda"
|
|
||||||
},
|
|
||||||
"model_dtype": {
|
|
||||||
"type": "string",
|
|
||||||
"example": "torch.float16"
|
|
||||||
},
|
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Model info",
|
"description": "Model info",
|
||||||
@ -1690,11 +1660,6 @@
|
|||||||
"version": {
|
"version": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "0.5.0"
|
"example": "0.5.0"
|
||||||
},
|
|
||||||
"waiting_served_ratio": {
|
|
||||||
"type": "number",
|
|
||||||
"format": "float",
|
|
||||||
"example": "1.2"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -431,20 +431,18 @@ Options:
|
|||||||
[env: LORA_ADAPTERS=]
|
[env: LORA_ADAPTERS=]
|
||||||
|
|
||||||
```
|
```
|
||||||
## DISABLE_USAGE_STATS
|
## USAGE_STATS
|
||||||
```shell
|
```shell
|
||||||
--disable-usage-stats
|
--usage-stats <USAGE_STATS>
|
||||||
Disable sending of all usage statistics
|
Control if anonymous usage stats are collected. Options are "on", "off" and "no-stack" Defaul is on
|
||||||
|
|
||||||
[env: DISABLE_USAGE_STATS=]
|
[env: USAGE_STATS=]
|
||||||
|
[default: on]
|
||||||
|
|
||||||
```
|
Possible values:
|
||||||
## DISABLE_CRASH_REPORTS
|
- on: Default option, usage statistics are collected anonymously
|
||||||
```shell
|
- off: Disables all collection of usage statistics
|
||||||
--disable-crash-reports
|
- no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
|
||||||
Disable sending of crash reports, but allow anonymous usage statistics
|
|
||||||
|
|
||||||
[env: DISABLE_CRASH_REPORTS=]
|
|
||||||
|
|
||||||
```
|
```
|
||||||
## HELP
|
## HELP
|
||||||
|
@ -36,6 +36,18 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
|||||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||||
```
|
```
|
||||||
|
|
||||||
|
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
||||||
|
```
|
||||||
|
|
||||||
|
note it's possible to mix adapter_ids with adapter_id=adapter_path e.g.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/
|
||||||
|
```
|
||||||
|
|
||||||
In the server logs, you will see the following message:
|
In the server logs, you will see the following message:
|
||||||
|
|
||||||
```txt
|
```txt
|
||||||
|
@ -70,4 +70,6 @@ As of release 2.1.2 this is an example of the data collected:
|
|||||||
|
|
||||||
## How to opt-out
|
## How to opt-out
|
||||||
|
|
||||||
You can easily opt out by passing the `--disable-usage-stats` to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports` which disables sending specific crash reports, but allows anonymous usage statistics.
|
By passing the `--usage-stats` to the text-generation-launcher you can control how much usage statistics are being collected.
|
||||||
|
`--usage-stats=no-stack` will not emit the stack traces from errors and the error types, but will continue to send start and stop events
|
||||||
|
`--usage-stats=off` will completely disable everything
|
||||||
|
@ -168,6 +168,33 @@ impl std::fmt::Display for RopeScaling {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
|
pub enum UsageStatsLevel {
|
||||||
|
/// Default option, usage statistics are collected anonymously
|
||||||
|
On,
|
||||||
|
/// Disables all collection of usage statistics
|
||||||
|
Off,
|
||||||
|
/// Doesn't send the error stack trace or error type, but allows sending a crash event
|
||||||
|
NoStack,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for UsageStatsLevel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
// To keep in track with `server`.
|
||||||
|
match self {
|
||||||
|
UsageStatsLevel::On => {
|
||||||
|
write!(f, "on")
|
||||||
|
}
|
||||||
|
UsageStatsLevel::Off => {
|
||||||
|
write!(f, "off")
|
||||||
|
}
|
||||||
|
UsageStatsLevel::NoStack => {
|
||||||
|
write!(f, "no-stack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[clap(author, version, about, long_about = None)]
|
#[clap(author, version, about, long_about = None)]
|
||||||
@ -466,13 +493,11 @@ struct Args {
|
|||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
lora_adapters: Option<String>,
|
lora_adapters: Option<String>,
|
||||||
|
|
||||||
/// Disable sending of all usage statistics
|
/// Control if anonymous usage stats are collected.
|
||||||
#[clap(default_value = "false", long, env)]
|
/// Options are "on", "off" and "no-stack"
|
||||||
disable_usage_stats: bool,
|
/// Defaul is on.
|
||||||
|
#[clap(default_value = "on", long, env)]
|
||||||
/// Disable sending of crash reports, but allow anonymous usage statistics
|
usage_stats: UsageStatsLevel,
|
||||||
#[clap(default_value = "false", long, env)]
|
|
||||||
disable_crash_reports: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -1218,12 +1243,8 @@ fn spawn_webserver(
|
|||||||
];
|
];
|
||||||
|
|
||||||
// Pass usage stats flags to router
|
// Pass usage stats flags to router
|
||||||
if args.disable_usage_stats {
|
router_args.push("--usage-stats".to_string());
|
||||||
router_args.push("--disable-usage-stats".to_string());
|
router_args.push(args.usage_stats.to_string());
|
||||||
}
|
|
||||||
if args.disable_crash_reports {
|
|
||||||
router_args.push("--disable-crash-reports".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Grammar support
|
// Grammar support
|
||||||
if args.disable_grammar_support {
|
if args.disable_grammar_support {
|
||||||
|
@ -7,25 +7,18 @@ edition.workspace = true
|
|||||||
authors.workspace = true
|
authors.workspace = true
|
||||||
homepage.workspace = true
|
homepage.workspace = true
|
||||||
|
|
||||||
[lib]
|
|
||||||
path = "src/lib.rs"
|
|
||||||
|
|
||||||
[[bin]]
|
|
||||||
name = "text-generation-router"
|
|
||||||
path = "src/main.rs"
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
async-trait = "0.1.74"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
axum = { version = "0.7", features = ["json"] }
|
axum = { version = "0.7", features = ["json"] }
|
||||||
axum-tracing-opentelemetry = "0.16"
|
axum-tracing-opentelemetry = "0.16"
|
||||||
text-generation-client = { path = "client" }
|
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
futures = "0.3.28"
|
futures = "0.3.28"
|
||||||
hf-hub = { workspace = true }
|
hf-hub = { workspace = true }
|
||||||
itertools = "0.10"
|
itertools = "0.10"
|
||||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||||
metrics = "0.23.0"
|
metrics = { workspace = true }
|
||||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
metrics-exporter-prometheus = { workspace = true }
|
||||||
nohash-hasher = "0.2.0"
|
nohash-hasher = "0.2.0"
|
||||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = "0.13.0"
|
opentelemetry-otlp = "0.13.0"
|
||||||
@ -55,6 +48,7 @@ base64 = { workspace = true }
|
|||||||
sysinfo = "0.30.13"
|
sysinfo = "0.30.13"
|
||||||
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
|
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
|
||||||
csv = "1.3.0"
|
csv = "1.3.0"
|
||||||
|
ureq = "=2.9"
|
||||||
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
1
router/client/src/v2/pb/.gitignore
vendored
1
router/client/src/v2/pb/.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
*
|
|
1
router/client/src/v3/pb/.gitignore
vendored
1
router/client/src/v3/pb/.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
*
|
|
@ -1,528 +1,85 @@
|
|||||||
/// Batching and inference logic
|
use crate::infer::InferError;
|
||||||
use crate::infer::v3::queue::{Entry, Queue};
|
use crate::{
|
||||||
use crate::infer::{
|
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
||||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
|
||||||
};
|
};
|
||||||
use crate::validation::ValidGenerateRequest;
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use crate::{FinishReason, PrefillToken, Token};
|
use minijinja_contrib::pycompat;
|
||||||
use nohash_hasher::IntMap;
|
|
||||||
use std::sync::{
|
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
|
||||||
};
|
|
||||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
|
||||||
use text_generation_client::ClientError;
|
|
||||||
use tokio::sync::mpsc::error::SendError;
|
|
||||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
|
|
||||||
use tokio::time::Instant;
|
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
|
||||||
|
|
||||||
pub(crate) struct SchedulerV3 {
|
/// Raise a exception (custom function) used in the chat templates
|
||||||
/// Request queue
|
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||||
queue: Queue,
|
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||||
/// Notify batcher on queue appends
|
|
||||||
batching_task_notifier: Arc<Notify>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SchedulerV3 {
|
#[derive(Clone)]
|
||||||
#[allow(clippy::too_many_arguments)]
|
pub(crate) struct ChatTemplate {
|
||||||
|
template: Template<'static, 'static>,
|
||||||
|
bos_token: Option<String>,
|
||||||
|
eos_token: Option<String>,
|
||||||
|
use_default_tool_template: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatTemplate {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
client: ShardedClient,
|
template: String,
|
||||||
waiting_served_ratio: f32,
|
bos_token: Option<TokenizerConfigToken>,
|
||||||
max_batch_prefill_tokens: u32,
|
eos_token: Option<TokenizerConfigToken>,
|
||||||
max_batch_total_tokens: u32,
|
|
||||||
max_waiting_tokens: usize,
|
|
||||||
max_batch_size: Option<usize>,
|
|
||||||
requires_padding: bool,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
speculate: u32,
|
|
||||||
generation_health: Arc<AtomicBool>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
let mut env = Box::new(Environment::new());
|
||||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
// enable things like .strip() or .capitalize()
|
||||||
} else {
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
false
|
let template_str = template.into_boxed_str();
|
||||||
};
|
env.add_function("raise_exception", raise_exception);
|
||||||
let block_size = if flashdecoding { 256 } else { 16 };
|
|
||||||
let queue = Queue::new(
|
|
||||||
requires_padding,
|
|
||||||
block_size,
|
|
||||||
window_size,
|
|
||||||
speculate,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
);
|
|
||||||
let batching_task_notifier = Arc::new(Notify::new());
|
|
||||||
|
|
||||||
// Spawn batching background task that contains all the inference logic
|
// check if contains the tools variable within the template
|
||||||
tokio::spawn(batching_task(
|
let use_default_tool_template =
|
||||||
client,
|
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||||
waiting_served_ratio,
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
max_batch_prefill_tokens,
|
let template = Box::leak(env)
|
||||||
max_batch_total_tokens,
|
.template_from_str(Box::leak(template_str))
|
||||||
max_waiting_tokens,
|
.unwrap();
|
||||||
max_batch_size,
|
|
||||||
queue.clone(),
|
|
||||||
batching_task_notifier.clone(),
|
|
||||||
generation_health,
|
|
||||||
));
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
queue,
|
template,
|
||||||
batching_task_notifier,
|
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||||
|
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||||
|
use_default_tool_template,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl Scheduler for SchedulerV3 {
|
pub(crate) fn apply(
|
||||||
#[instrument(skip_all)]
|
|
||||||
fn schedule(
|
|
||||||
&self,
|
&self,
|
||||||
request: ValidGenerateRequest,
|
mut messages: Vec<Message>,
|
||||||
permit: OwnedSemaphorePermit,
|
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||||
) -> Result<GenerateStreamResponse, InferError> {
|
) -> Result<String, InferError> {
|
||||||
// MPSC channel to communicate with the background batching task
|
if self.use_default_tool_template {
|
||||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
if let Some(last_message) = messages.last_mut() {
|
||||||
let input_length = request.input_length;
|
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||||
|
last_message.content.push(MessageChunk::Text {
|
||||||
// Append the request to the queue
|
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||||
self.queue.append(Entry {
|
|
||||||
request,
|
|
||||||
response_tx,
|
|
||||||
span: Span::current(),
|
|
||||||
temp_span: None,
|
|
||||||
queue_time: Instant::now(),
|
|
||||||
batch_time: None,
|
|
||||||
block_allocation: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Notify the background task that we have a new entry in the queue that needs
|
|
||||||
// to be batched
|
|
||||||
self.batching_task_notifier.notify_one();
|
|
||||||
|
|
||||||
// Return stream
|
|
||||||
Ok((
|
|
||||||
permit,
|
|
||||||
input_length,
|
|
||||||
UnboundedReceiverStream::new(response_rx),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Batching logic
|
|
||||||
/// Will be launched in a background Tokio task
|
|
||||||
///
|
|
||||||
/// Batches requests and sends them to the inference server
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub(crate) async fn batching_task(
|
|
||||||
mut client: ShardedClient,
|
|
||||||
waiting_served_ratio: f32,
|
|
||||||
max_batch_prefill_tokens: u32,
|
|
||||||
max_batch_total_tokens: u32,
|
|
||||||
max_waiting_tokens: usize,
|
|
||||||
max_batch_size: Option<usize>,
|
|
||||||
queue: Queue,
|
|
||||||
notifier: Arc<Notify>,
|
|
||||||
generation_health: Arc<AtomicBool>,
|
|
||||||
) {
|
|
||||||
// Infinite loop
|
|
||||||
loop {
|
|
||||||
// Wait for a notification from the Infer struct
|
|
||||||
notifier.notified().await;
|
|
||||||
|
|
||||||
// Get the next batch from the queue
|
|
||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
|
||||||
// waiting in the queue
|
|
||||||
while let Some((mut entries, batch, span)) = queue
|
|
||||||
.next_batch(
|
|
||||||
None,
|
|
||||||
max_batch_size,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
|
||||||
.instrument(span)
|
|
||||||
.await;
|
|
||||||
let mut waiting_tokens = 1;
|
|
||||||
|
|
||||||
// We loop until we do not receive any cached batch from the inference server (== until
|
|
||||||
// all requests have met their stopping criteria)
|
|
||||||
while let Some(batch) = cached_batch {
|
|
||||||
// Get current batch info
|
|
||||||
let batch_size = batch.size;
|
|
||||||
let batch_max_tokens = batch.max_tokens;
|
|
||||||
let mut batches = vec![batch];
|
|
||||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
|
||||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
|
||||||
|
|
||||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
|
||||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
|
||||||
// to add a new batch even though its size might be small
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
// Minimum batch size
|
|
||||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
|
||||||
};
|
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
|
||||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
|
||||||
|
|
||||||
// Try to get a new batch
|
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
|
||||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
// Tracking metrics
|
|
||||||
if min_size.is_some() {
|
|
||||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
|
||||||
.increment(1);
|
|
||||||
} else {
|
|
||||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
|
||||||
.increment(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
entries.iter_mut().for_each(|(_, entry)| {
|
|
||||||
// Create a new span to add the info that this entry is waiting
|
|
||||||
// because a new batch is being computed
|
|
||||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
|
||||||
// Add relationships
|
|
||||||
span.follows_from(&entry_waiting_span);
|
|
||||||
entry_waiting_span.follows_from(&span);
|
|
||||||
// Update entry
|
|
||||||
entry.temp_span = Some(entry_waiting_span);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
|
||||||
let new_cached_batch =
|
|
||||||
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
|
||||||
.instrument(span)
|
|
||||||
.await;
|
|
||||||
// Reset waiting counter
|
|
||||||
waiting_tokens = 1;
|
|
||||||
// Extend current batch with the new batch
|
|
||||||
if let Some(new_cached_batch) = new_cached_batch {
|
|
||||||
entries.extend(new_entries);
|
|
||||||
batches.push(new_cached_batch);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create span for this batch to add context to inference calls
|
|
||||||
let next_batch_size = entries.len();
|
|
||||||
let next_batch_span =
|
|
||||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
|
||||||
entries.iter_mut().for_each(|(_, entry)| {
|
|
||||||
// Create a new span to link the batch back to this entry
|
|
||||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
|
||||||
// Add relationships
|
|
||||||
next_batch_span.follows_from(&entry_batch_span);
|
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
|
||||||
// Update entry
|
|
||||||
entry.temp_span = Some(entry_batch_span);
|
|
||||||
});
|
|
||||||
|
|
||||||
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
|
||||||
.instrument(next_batch_span)
|
|
||||||
.await;
|
|
||||||
waiting_tokens += 1;
|
|
||||||
}
|
|
||||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
|
||||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn prefill(
|
|
||||||
client: &mut ShardedClient,
|
|
||||||
batch: Batch,
|
|
||||||
entries: &mut IntMap<u64, Entry>,
|
|
||||||
generation_health: &Arc<AtomicBool>,
|
|
||||||
) -> Option<CachedBatch> {
|
|
||||||
let start_time = Instant::now();
|
|
||||||
let batch_id = batch.id;
|
|
||||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
|
||||||
|
|
||||||
match client.prefill(batch).await {
|
|
||||||
Ok((generations, next_batch, timings)) => {
|
|
||||||
// Update health
|
|
||||||
generation_health.store(true, Ordering::SeqCst);
|
|
||||||
|
|
||||||
let start_filtering_time = Instant::now();
|
|
||||||
// Send generated tokens and filter stopped entries
|
|
||||||
filter_send_generations(generations, entries);
|
|
||||||
|
|
||||||
// Filter next batch and remove requests that were stopped
|
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
|
||||||
|
|
||||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
|
||||||
.record(timings.forward.as_secs_f64());
|
|
||||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
|
||||||
.record(timings.decode.as_secs_f64());
|
|
||||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
|
||||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
|
||||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
|
||||||
.record(start_time.elapsed().as_secs_f64());
|
|
||||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
|
||||||
next_batch
|
|
||||||
}
|
|
||||||
// If we have an error, we discard the whole batch
|
|
||||||
Err(err) => {
|
|
||||||
// Update health
|
|
||||||
generation_health.store(false, Ordering::SeqCst);
|
|
||||||
let _ = client.clear_cache(Some(batch_id)).await;
|
|
||||||
send_errors(err, entries);
|
|
||||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn decode(
|
|
||||||
client: &mut ShardedClient,
|
|
||||||
batches: Vec<CachedBatch>,
|
|
||||||
entries: &mut IntMap<u64, Entry>,
|
|
||||||
generation_health: &Arc<AtomicBool>,
|
|
||||||
) -> Option<CachedBatch> {
|
|
||||||
let start_time = Instant::now();
|
|
||||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
|
||||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
|
||||||
|
|
||||||
match client.decode(batches).await {
|
|
||||||
Ok((generations, next_batch, timings)) => {
|
|
||||||
// Update health
|
|
||||||
generation_health.store(true, Ordering::SeqCst);
|
|
||||||
|
|
||||||
let start_filtering_time = Instant::now();
|
|
||||||
// Send generated tokens and filter stopped entries
|
|
||||||
filter_send_generations(generations, entries);
|
|
||||||
|
|
||||||
// Filter next batch and remove requests that were stopped
|
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
|
||||||
|
|
||||||
if let Some(concat_duration) = timings.concat {
|
|
||||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
|
||||||
.record(concat_duration.as_secs_f64());
|
|
||||||
}
|
|
||||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
|
||||||
.record(timings.forward.as_secs_f64());
|
|
||||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
|
||||||
.record(timings.decode.as_secs_f64());
|
|
||||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
|
||||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
|
||||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
|
||||||
.record(start_time.elapsed().as_secs_f64());
|
|
||||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
|
||||||
next_batch
|
|
||||||
}
|
|
||||||
// If we have an error, we discard the whole batch
|
|
||||||
Err(err) => {
|
|
||||||
generation_health.store(false, Ordering::SeqCst);
|
|
||||||
for id in batch_ids {
|
|
||||||
let _ = client.clear_cache(Some(id)).await;
|
|
||||||
}
|
|
||||||
send_errors(err, entries);
|
|
||||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Filter a `batch` and remove all requests not present in `entries`
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
async fn filter_batch(
|
|
||||||
client: &mut ShardedClient,
|
|
||||||
next_batch: Option<CachedBatch>,
|
|
||||||
entries: &IntMap<u64, Entry>,
|
|
||||||
) -> Option<CachedBatch> {
|
|
||||||
let mut batch = next_batch?;
|
|
||||||
|
|
||||||
// No need to filter
|
|
||||||
if batch.size as usize == entries.len() {
|
|
||||||
return Some(batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
let id = batch.id;
|
|
||||||
|
|
||||||
// Retain only requests that are still in entries
|
|
||||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
|
||||||
|
|
||||||
if batch.request_ids.is_empty() {
|
|
||||||
// All requests have been filtered out
|
|
||||||
// Next batch is now empty
|
|
||||||
// Clear it from the Python shards cache
|
|
||||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
|
||||||
client.clear_cache(Some(id)).await.unwrap();
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
// Filter Python shard cache
|
|
||||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
|
||||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
|
||||||
/// and filter entries
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
|
||||||
generations.into_iter().for_each(|generation| {
|
|
||||||
let id = generation.request_id;
|
|
||||||
// Get entry
|
|
||||||
// We can `expect` here as the request id should always be in the entries
|
|
||||||
let entry = entries
|
|
||||||
.get(&id)
|
|
||||||
.expect("ID not found in entries. This is a bug.");
|
|
||||||
|
|
||||||
// Create and enter a span to link this function back to the entry
|
|
||||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
|
||||||
// Send generation responses back to the infer task
|
|
||||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
|
||||||
// request and we need to stop generating hence why we unwrap_or(true)
|
|
||||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
|
||||||
tracing::error!("Entry response channel error.");
|
|
||||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
|
||||||
err
|
|
||||||
}).unwrap_or(true);
|
|
||||||
if stopped {
|
|
||||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send responses through the `entry` response channel
|
|
||||||
fn send_responses(
|
|
||||||
generation: Generation,
|
|
||||||
entry: &Entry,
|
|
||||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
|
||||||
// Return directly if the channel is disconnected
|
|
||||||
if entry.response_tx.is_closed() {
|
|
||||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
|
||||||
return Ok(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut stopped = false;
|
|
||||||
|
|
||||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
|
||||||
// Create Token objects
|
|
||||||
// We do that here instead of in the Python code as Rust for loops are faster
|
|
||||||
let prefill_tokens = prefill_tokens
|
|
||||||
.ids
|
|
||||||
.into_iter()
|
|
||||||
.zip(prefill_tokens.logprobs)
|
|
||||||
.zip(prefill_tokens.texts)
|
|
||||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Send message
|
|
||||||
entry
|
|
||||||
.response_tx
|
|
||||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create last Token
|
|
||||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
|
||||||
let n = tokens_.ids.len();
|
|
||||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
|
||||||
let mut iterator = tokens_
|
|
||||||
.ids
|
|
||||||
.into_iter()
|
|
||||||
.zip(tokens_.logprobs)
|
|
||||||
.zip(tokens_.texts)
|
|
||||||
.zip(tokens_.is_special)
|
|
||||||
.enumerate()
|
|
||||||
.peekable();
|
|
||||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
|
||||||
let token = Token {
|
|
||||||
id,
|
|
||||||
text,
|
|
||||||
logprob,
|
|
||||||
special,
|
|
||||||
};
|
|
||||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
|
||||||
top_tokens_
|
|
||||||
.ids
|
|
||||||
.iter()
|
|
||||||
.zip(top_tokens_.logprobs.iter())
|
|
||||||
.zip(top_tokens_.texts.iter())
|
|
||||||
.zip(top_tokens_.is_special.iter())
|
|
||||||
.map(|(((&id, &logprob), text), &special)| Token {
|
|
||||||
id,
|
|
||||||
text: text.to_string(),
|
|
||||||
logprob,
|
|
||||||
special,
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
vec![]
|
|
||||||
};
|
|
||||||
match (&generation.generated_text, iterator.peek()) {
|
|
||||||
(Some(generated_text), None) => {
|
|
||||||
// Generation has ended
|
|
||||||
stopped = true;
|
|
||||||
// Send message
|
|
||||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
|
||||||
token,
|
|
||||||
top_tokens,
|
|
||||||
generated_text: GeneratedText::from(generated_text.clone()),
|
|
||||||
queued: entry.queue_time,
|
|
||||||
start: entry.batch_time.unwrap(),
|
|
||||||
}))?;
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Send message
|
|
||||||
entry
|
|
||||||
.response_tx
|
|
||||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(stopped)
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
}
|
|
||||||
|
|
||||||
/// Send errors to Infer for all `entries`
|
self.template
|
||||||
#[instrument(skip_all)]
|
.render(ChatTemplateInputs {
|
||||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
messages,
|
||||||
entries.drain().for_each(|(_, entry)| {
|
bos_token: self.bos_token.as_deref(),
|
||||||
// Create and enter a span to link this function back to the entry
|
eos_token: self.eos_token.as_deref(),
|
||||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
add_generation_prompt: true,
|
||||||
let err = InferError::GenerationError(error.to_string());
|
tools: None,
|
||||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
tools_prompt: None,
|
||||||
tracing::error!("{err}");
|
})
|
||||||
|
.map_err(InferError::TemplateError)
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
|
||||||
entry
|
|
||||||
.response_tx
|
|
||||||
.send(Err(err))
|
|
||||||
.unwrap_or(());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
|
|
||||||
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
|
|
||||||
let v3_finish_reason =
|
|
||||||
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
|
|
||||||
let finish_reason = match v3_finish_reason {
|
|
||||||
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
|
|
||||||
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
|
||||||
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
|
||||||
text: value.text,
|
|
||||||
generated_tokens: value.generated_tokens,
|
|
||||||
finish_reason,
|
|
||||||
seed: value.seed,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tests
|
// tests
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::infer::raise_exception;
|
use crate::infer::chat_template::raise_exception;
|
||||||
use crate::{ChatTemplateInputs, TextMessage};
|
use crate::{ChatTemplateInputs, TextMessage};
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
|
|
@ -1,34 +0,0 @@
|
|||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use text_generation_client::Health;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub(crate) struct HealthCheck {
|
|
||||||
client: Arc<dyn Health + Send + Sync>,
|
|
||||||
generation_health: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HealthCheck {
|
|
||||||
pub(crate) fn new(
|
|
||||||
client: Arc<dyn Health + Send + Sync>,
|
|
||||||
generation_health: Arc<AtomicBool>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
client,
|
|
||||||
generation_health,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn check(&mut self) -> bool {
|
|
||||||
let value = if self.generation_health.load(Ordering::SeqCst) {
|
|
||||||
// Generation is healthy, we only check that the shards can allocate on device
|
|
||||||
self.client.device_health().await
|
|
||||||
} else {
|
|
||||||
self.client.model_health().await
|
|
||||||
}
|
|
||||||
.is_ok();
|
|
||||||
// Update generation health
|
|
||||||
self.generation_health.store(value, Ordering::SeqCst);
|
|
||||||
value
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,23 +1,18 @@
|
|||||||
mod health;
|
// pub(crate) mod v2;
|
||||||
pub(crate) mod v2;
|
mod chat_template;
|
||||||
pub(crate) mod v3;
|
pub mod tool_grammar;
|
||||||
|
|
||||||
pub(crate) use health::HealthCheck;
|
|
||||||
|
|
||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
|
use crate::GrammarType;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
|
Message, PrefillToken, Token,
|
||||||
};
|
|
||||||
use crate::{
|
|
||||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
|
||||||
};
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chat_template::ChatTemplate;
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::ErrorKind;
|
||||||
use minijinja_contrib::pycompat;
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
|
||||||
use serde_json::{json, Map, Value};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||||
@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
|||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
pub(crate) trait Scheduler {
|
#[async_trait]
|
||||||
|
pub trait Backend {
|
||||||
fn schedule(
|
fn schedule(
|
||||||
&self,
|
&self,
|
||||||
request: ValidGenerateRequest,
|
request: ValidGenerateRequest,
|
||||||
permit: OwnedSemaphorePermit,
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
|
||||||
) -> Result<GenerateStreamResponse, InferError>;
|
|
||||||
|
async fn health(&self, current_health: bool) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inference struct
|
/// Inference struct
|
||||||
@ -39,18 +36,20 @@ pub(crate) trait Scheduler {
|
|||||||
pub struct Infer {
|
pub struct Infer {
|
||||||
/// Validation
|
/// Validation
|
||||||
validation: Validation,
|
validation: Validation,
|
||||||
/// Request scheduler
|
/// Request backend
|
||||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
backend: Arc<dyn Backend + Send + Sync>,
|
||||||
/// Chat template
|
/// Chat template
|
||||||
chat_template: Option<ChatTemplate>,
|
chat_template: Option<ChatTemplate>,
|
||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
|
/// Backend health
|
||||||
|
backend_health: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Infer {
|
impl Infer {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
backend: impl Backend + Send + Sync + 'static,
|
||||||
validation: Validation,
|
validation: Validation,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
@ -71,18 +70,22 @@ impl Infer {
|
|||||||
// Inference limit with a semaphore
|
// Inference limit with a semaphore
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
|
// Backend health
|
||||||
|
let backend_health = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
scheduler,
|
backend: Arc::new(backend),
|
||||||
chat_template,
|
chat_template,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
|
backend_health,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) async fn generate_stream(
|
pub(crate) async fn generate_stream<'a>(
|
||||||
&self,
|
&'a self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<GenerateStreamResponse, InferError> {
|
) -> Result<GenerateStreamResponse, InferError> {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
@ -103,7 +106,10 @@ impl Infer {
|
|||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
self.scheduler.schedule(valid_request, permit)
|
let input_length = valid_request.input_length;
|
||||||
|
let generation_stream = self.backend.schedule(valid_request)?;
|
||||||
|
|
||||||
|
Ok((permit, input_length, generation_stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tokenizer the input
|
/// Tokenizer the input
|
||||||
@ -155,7 +161,7 @@ impl Infer {
|
|||||||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||||
|
|
||||||
// Create stream and keep semaphore permit as long as generate lives
|
// Create stream and keep semaphore permit as long as generate lives
|
||||||
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
let (_permit, _input_length, stream) = self.generate_stream(request).await?;
|
||||||
|
|
||||||
// Return values
|
// Return values
|
||||||
let mut result_prefill = Vec::new();
|
let mut result_prefill = Vec::new();
|
||||||
@ -165,6 +171,8 @@ impl Infer {
|
|||||||
let mut result_start = None;
|
let mut result_start = None;
|
||||||
let mut result_queued = None;
|
let mut result_queued = None;
|
||||||
|
|
||||||
|
let mut stream = Box::pin(stream);
|
||||||
|
|
||||||
// Iterate on stream
|
// Iterate on stream
|
||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
match response? {
|
match response? {
|
||||||
@ -256,207 +264,15 @@ impl Infer {
|
|||||||
let best_response = infer_responses.remove(max_index);
|
let best_response = infer_responses.remove(max_index);
|
||||||
Ok((best_response, infer_responses))
|
Ok((best_response, infer_responses))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Raise a exception (custom function) used in the chat templates
|
#[instrument(skip(self))]
|
||||||
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
pub(crate) async fn health(&self) -> bool {
|
||||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
let health = self
|
||||||
}
|
.backend
|
||||||
|
.health(self.backend_health.load(Ordering::SeqCst))
|
||||||
#[derive(Clone)]
|
.await;
|
||||||
struct ChatTemplate {
|
self.backend_health.store(health, Ordering::SeqCst);
|
||||||
template: Template<'static, 'static>,
|
health
|
||||||
bos_token: Option<String>,
|
|
||||||
eos_token: Option<String>,
|
|
||||||
use_default_tool_template: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatTemplate {
|
|
||||||
fn new(
|
|
||||||
template: String,
|
|
||||||
bos_token: Option<TokenizerConfigToken>,
|
|
||||||
eos_token: Option<TokenizerConfigToken>,
|
|
||||||
) -> Self {
|
|
||||||
let mut env = Box::new(Environment::new());
|
|
||||||
// enable things like .strip() or .capitalize()
|
|
||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
|
||||||
let template_str = template.into_boxed_str();
|
|
||||||
env.add_function("raise_exception", raise_exception);
|
|
||||||
|
|
||||||
// check if contains the tools variable within the template
|
|
||||||
let use_default_tool_template =
|
|
||||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
|
||||||
let template = Box::leak(env)
|
|
||||||
.template_from_str(Box::leak(template_str))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
template,
|
|
||||||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
|
||||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
|
||||||
use_default_tool_template,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply(
|
|
||||||
&self,
|
|
||||||
mut messages: Vec<Message>,
|
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
|
||||||
) -> Result<String, InferError> {
|
|
||||||
if self.use_default_tool_template {
|
|
||||||
if let Some(last_message) = messages.last_mut() {
|
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
|
||||||
last_message.content.push(MessageChunk::Text {
|
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
|
||||||
|
|
||||||
self.template
|
|
||||||
.render(ChatTemplateInputs {
|
|
||||||
messages,
|
|
||||||
bos_token: self.bos_token.as_deref(),
|
|
||||||
eos_token: self.eos_token.as_deref(),
|
|
||||||
add_generation_prompt: true,
|
|
||||||
tools: None,
|
|
||||||
tools_prompt: None,
|
|
||||||
})
|
|
||||||
.map_err(InferError::TemplateError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ToolGrammar {}
|
|
||||||
|
|
||||||
impl ToolGrammar {
|
|
||||||
// find a tool by name
|
|
||||||
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
|
||||||
tools
|
|
||||||
.iter()
|
|
||||||
.find(|tool| tool.function.name == name)
|
|
||||||
.cloned()
|
|
||||||
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply(
|
|
||||||
tools: Option<Vec<Tool>>,
|
|
||||||
tool_choice: ToolChoice,
|
|
||||||
) -> Result<Option<Tools>, InferError> {
|
|
||||||
// if no tools are provided, we return None
|
|
||||||
let tools = match tools {
|
|
||||||
Some(tools) if !tools.is_empty() => tools,
|
|
||||||
_ => return Ok(None),
|
|
||||||
};
|
|
||||||
|
|
||||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
|
||||||
|
|
||||||
// if tools are provided and no tool_choice we default to the OneOf
|
|
||||||
let tools_to_use = match tool_choice {
|
|
||||||
ToolType::FunctionName(name) => {
|
|
||||||
vec![Self::find_tool_by_name(&tools, &name)?]
|
|
||||||
}
|
|
||||||
ToolType::Function { function } => {
|
|
||||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
|
||||||
}
|
|
||||||
ToolType::OneOf => tools,
|
|
||||||
ToolType::NoTool => return Ok(None),
|
|
||||||
};
|
|
||||||
|
|
||||||
// adds the error notification function for LLM feedback if required
|
|
||||||
let mut text_response_properties = Map::new();
|
|
||||||
text_response_properties.insert(
|
|
||||||
"error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"description": "The error or issue to notify"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
text_response_properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": "notify_error"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
|
||||||
.iter()
|
|
||||||
.map(|tool| {
|
|
||||||
let func = tool.function.clone();
|
|
||||||
|
|
||||||
// Clone the existing parameters, which are expected to be a JSON object
|
|
||||||
let mut params = if let Value::Object(params) = &func.arguments {
|
|
||||||
params.clone()
|
|
||||||
} else {
|
|
||||||
Map::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Insert the function's description at the top level, outside of properties
|
|
||||||
params.insert(
|
|
||||||
"description".to_string(),
|
|
||||||
Value::String(func.description.clone().unwrap_or_default()),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Ensure 'properties' exists and is an object
|
|
||||||
let properties = params
|
|
||||||
.entry("properties".to_string())
|
|
||||||
.or_insert_with(|| json!({}))
|
|
||||||
.as_object_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Insert the constant for the function name inside 'properties'
|
|
||||||
properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": func.name.clone(),
|
|
||||||
// "description": "The name of the function"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
|
||||||
let required = params
|
|
||||||
.entry("required".to_string())
|
|
||||||
.or_insert_with(|| json!([]))
|
|
||||||
.as_array_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Add 'name' to the 'required' array if it is not already present
|
|
||||||
if !required.iter().any(|r| r == "_name") {
|
|
||||||
required.push(json!("_name"));
|
|
||||||
}
|
|
||||||
|
|
||||||
(func.name, Value::Object(params))
|
|
||||||
})
|
|
||||||
.chain([(
|
|
||||||
"notify_error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"properties": text_response_properties,
|
|
||||||
"required": ["error", "_name"],
|
|
||||||
"type": "object"
|
|
||||||
}),
|
|
||||||
)])
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let tools = Tools {
|
|
||||||
functions_map: FunctionsMap { functions },
|
|
||||||
properties: Properties {
|
|
||||||
function: tools_to_use
|
|
||||||
.iter()
|
|
||||||
.map(|tool| FunctionRef {
|
|
||||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
|
||||||
})
|
|
||||||
.chain(std::iter::once(FunctionRef {
|
|
||||||
ref_path: "#/$functions/notify_error".to_string(),
|
|
||||||
}))
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Some(tools))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct GeneratedText {
|
pub struct GeneratedText {
|
||||||
pub(crate) text: String,
|
pub text: String,
|
||||||
pub(crate) generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
pub(crate) finish_reason: FinishReason,
|
pub finish_reason: FinishReason,
|
||||||
pub(crate) seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) enum InferStreamResponse {
|
pub enum InferStreamResponse {
|
||||||
// Optional first message
|
// Optional first message
|
||||||
Prefill(Vec<PrefillToken>),
|
Prefill(Vec<PrefillToken>),
|
||||||
// Intermediate messages
|
// Intermediate messages
|
||||||
|
135
router/src/infer/tool_grammar.rs
Normal file
135
router/src/infer/tool_grammar.rs
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
use crate::infer::InferError;
|
||||||
|
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
||||||
|
use serde_json::{json, Map, Value};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
pub(crate) struct ToolGrammar {}
|
||||||
|
|
||||||
|
impl ToolGrammar {
|
||||||
|
// find a tool by name
|
||||||
|
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.find(|tool| tool.function.name == name)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply(
|
||||||
|
tools: Option<Vec<Tool>>,
|
||||||
|
tool_choice: ToolChoice,
|
||||||
|
) -> Result<Option<Tools>, InferError> {
|
||||||
|
// if no tools are provided, we return None
|
||||||
|
let tools = match tools {
|
||||||
|
Some(tools) if !tools.is_empty() => tools,
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||||
|
|
||||||
|
// if tools are provided and no tool_choice we default to the OneOf
|
||||||
|
let tools_to_use = match tool_choice {
|
||||||
|
ToolType::FunctionName(name) => {
|
||||||
|
vec![Self::find_tool_by_name(&tools, &name)?]
|
||||||
|
}
|
||||||
|
ToolType::Function { function } => {
|
||||||
|
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||||
|
}
|
||||||
|
ToolType::OneOf => tools,
|
||||||
|
ToolType::NoTool => return Ok(None),
|
||||||
|
};
|
||||||
|
|
||||||
|
// adds the error notification function for LLM feedback if required
|
||||||
|
let mut text_response_properties = Map::new();
|
||||||
|
text_response_properties.insert(
|
||||||
|
"error".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "string",
|
||||||
|
"description": "The error or issue to notify"
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
text_response_properties.insert(
|
||||||
|
"_name".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "string",
|
||||||
|
"const": "notify_error"
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||||
|
.iter()
|
||||||
|
.map(|tool| {
|
||||||
|
let func = tool.function.clone();
|
||||||
|
|
||||||
|
// Clone the existing parameters, which are expected to be a JSON object
|
||||||
|
let mut params = if let Value::Object(params) = &func.arguments {
|
||||||
|
params.clone()
|
||||||
|
} else {
|
||||||
|
Map::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Insert the function's description at the top level, outside of properties
|
||||||
|
params.insert(
|
||||||
|
"description".to_string(),
|
||||||
|
Value::String(func.description.clone().unwrap_or_default()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Ensure 'properties' exists and is an object
|
||||||
|
let properties = params
|
||||||
|
.entry("properties".to_string())
|
||||||
|
.or_insert_with(|| json!({}))
|
||||||
|
.as_object_mut()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Insert the constant for the function name inside 'properties'
|
||||||
|
properties.insert(
|
||||||
|
"_name".to_string(),
|
||||||
|
json!({
|
||||||
|
"type": "string",
|
||||||
|
"const": func.name.clone(),
|
||||||
|
// "description": "The name of the function"
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||||
|
let required = params
|
||||||
|
.entry("required".to_string())
|
||||||
|
.or_insert_with(|| json!([]))
|
||||||
|
.as_array_mut()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Add 'name' to the 'required' array if it is not already present
|
||||||
|
if !required.iter().any(|r| r == "_name") {
|
||||||
|
required.push(json!("_name"));
|
||||||
|
}
|
||||||
|
|
||||||
|
(func.name, Value::Object(params))
|
||||||
|
})
|
||||||
|
.chain([(
|
||||||
|
"notify_error".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"properties": text_response_properties,
|
||||||
|
"required": ["error", "_name"],
|
||||||
|
"type": "object"
|
||||||
|
}),
|
||||||
|
)])
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let tools = Tools {
|
||||||
|
functions_map: FunctionsMap { functions },
|
||||||
|
properties: Properties {
|
||||||
|
function: tools_to_use
|
||||||
|
.iter()
|
||||||
|
.map(|tool| FunctionRef {
|
||||||
|
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||||
|
})
|
||||||
|
.chain(std::iter::once(FunctionRef {
|
||||||
|
ref_path: "#/$functions/notify_error".to_string(),
|
||||||
|
}))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(tools))
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
mod queue;
|
mod queue;
|
||||||
mod scheduler;
|
mod scheduler;
|
||||||
|
|
||||||
pub(crate) use scheduler::SchedulerV2;
|
pub(crate) use scheduler::BackendV2;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::infer::v2::queue::{Entry, Queue};
|
use crate::infer::v2::queue::{Entry, Queue};
|
||||||
use crate::infer::{
|
use crate::infer::{
|
||||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidGenerateRequest;
|
use crate::validation::ValidGenerateRequest;
|
||||||
use crate::{FinishReason, PrefillToken, Token};
|
use crate::{FinishReason, PrefillToken, Token};
|
||||||
@ -18,14 +18,14 @@ use tokio::time::Instant;
|
|||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
pub(crate) struct SchedulerV2 {
|
pub(crate) struct BackendV2 {
|
||||||
/// Request queue
|
/// Request queue
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
/// Notify batcher on queue appends
|
/// Notify batcher on queue appends
|
||||||
batching_task_notifier: Arc<Notify>,
|
batching_task_notifier: Arc<Notify>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SchedulerV2 {
|
impl BackendV2 {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
@ -69,7 +69,7 @@ impl SchedulerV2 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Scheduler for SchedulerV2 {
|
impl Backend for BackendV2 {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
fn schedule(
|
fn schedule(
|
||||||
&self,
|
&self,
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
mod block_allocator;
|
|
||||||
mod queue;
|
|
||||||
mod scheduler;
|
|
||||||
|
|
||||||
pub(crate) use scheduler::SchedulerV3;
|
|
@ -1,11 +1,12 @@
|
|||||||
/// Text Generation Inference Webserver
|
/// Text Generation Inference Webserver
|
||||||
pub mod config;
|
pub mod config;
|
||||||
mod infer;
|
pub mod infer;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
mod validation;
|
pub mod validation;
|
||||||
|
|
||||||
#[cfg(feature = "kserve")]
|
#[cfg(feature = "kserve")]
|
||||||
mod kserve;
|
mod kserve;
|
||||||
|
pub mod logging;
|
||||||
|
|
||||||
pub mod usage_stats;
|
pub mod usage_stats;
|
||||||
|
|
||||||
@ -148,12 +149,13 @@ pub struct Info {
|
|||||||
pub model_id: String,
|
pub model_id: String,
|
||||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||||
pub model_sha: Option<String>,
|
pub model_sha: Option<String>,
|
||||||
#[schema(example = "torch.float16")]
|
// #[schema(example = "torch.float16")]
|
||||||
pub model_dtype: String,
|
// pub model_dtype: String,
|
||||||
#[schema(example = "cuda")]
|
// #[schema(example = "cuda")]
|
||||||
pub model_device_type: String,
|
// pub model_device_type: String,
|
||||||
#[schema(nullable = true, example = "text-generation")]
|
#[schema(nullable = true, example = "text-generation")]
|
||||||
pub model_pipeline_tag: Option<String>,
|
pub model_pipeline_tag: Option<String>,
|
||||||
|
|
||||||
/// Router Parameters
|
/// Router Parameters
|
||||||
#[schema(example = "128")]
|
#[schema(example = "128")]
|
||||||
pub max_concurrent_requests: usize,
|
pub max_concurrent_requests: usize,
|
||||||
@ -165,18 +167,11 @@ pub struct Info {
|
|||||||
pub max_input_tokens: usize,
|
pub max_input_tokens: usize,
|
||||||
#[schema(example = "2048")]
|
#[schema(example = "2048")]
|
||||||
pub max_total_tokens: usize,
|
pub max_total_tokens: usize,
|
||||||
#[schema(example = "1.2")]
|
|
||||||
pub waiting_served_ratio: f32,
|
|
||||||
#[schema(example = "32000")]
|
|
||||||
pub max_batch_total_tokens: u32,
|
|
||||||
#[schema(example = "20")]
|
|
||||||
pub max_waiting_tokens: usize,
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub max_batch_size: Option<usize>,
|
|
||||||
#[schema(example = "2")]
|
#[schema(example = "2")]
|
||||||
pub validation_workers: usize,
|
pub validation_workers: usize,
|
||||||
#[schema(example = "32")]
|
#[schema(example = "32")]
|
||||||
pub max_client_batch_size: usize,
|
pub max_client_batch_size: usize,
|
||||||
|
|
||||||
/// Router Info
|
/// Router Info
|
||||||
#[schema(example = "text-generation-router")]
|
#[schema(example = "text-generation-router")]
|
||||||
pub router: &'static str,
|
pub router: &'static str,
|
||||||
@ -624,7 +619,7 @@ impl ChatCompletion {
|
|||||||
message,
|
message,
|
||||||
logprobs: return_logprobs
|
logprobs: return_logprobs
|
||||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||||
finish_reason: details.finish_reason.to_string(),
|
finish_reason: details.finish_reason.format(true),
|
||||||
}],
|
}],
|
||||||
usage: Usage {
|
usage: Usage {
|
||||||
prompt_tokens: details.prefill.len() as u32,
|
prompt_tokens: details.prefill.len() as u32,
|
||||||
@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
|
|||||||
#[derive(Debug, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
pub struct PrefillToken {
|
pub struct PrefillToken {
|
||||||
#[schema(example = 0)]
|
#[schema(example = 0)]
|
||||||
id: u32,
|
pub id: u32,
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
text: String,
|
pub text: String,
|
||||||
#[schema(nullable = true, example = - 0.34)]
|
#[schema(nullable = true, example = - 0.34)]
|
||||||
logprob: f32,
|
pub logprob: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||||
pub struct Token {
|
pub struct Token {
|
||||||
#[schema(example = 0)]
|
#[schema(example = 0)]
|
||||||
id: u32,
|
pub id: u32,
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
text: String,
|
pub text: String,
|
||||||
#[schema(nullable = true, example = - 0.34)]
|
#[schema(nullable = true, example = - 0.34)]
|
||||||
logprob: f32,
|
pub logprob: f32,
|
||||||
#[schema(example = "false")]
|
#[schema(example = "false")]
|
||||||
special: bool,
|
pub special: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
@ -1102,7 +1097,7 @@ pub struct SimpleToken {
|
|||||||
#[derive(Debug, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
#[serde(rename_all(serialize = "snake_case"))]
|
#[serde(rename_all(serialize = "snake_case"))]
|
||||||
#[schema(example = "Length")]
|
#[schema(example = "Length")]
|
||||||
pub(crate) enum FinishReason {
|
pub enum FinishReason {
|
||||||
#[schema(rename = "length")]
|
#[schema(rename = "length")]
|
||||||
Length,
|
Length,
|
||||||
#[serde(rename = "eos_token")]
|
#[serde(rename = "eos_token")]
|
||||||
@ -1122,6 +1117,15 @@ impl std::fmt::Display for FinishReason {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FinishReason {
|
||||||
|
pub fn format(&self, use_stop: bool) -> String {
|
||||||
|
match self {
|
||||||
|
FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(),
|
||||||
|
_ => self.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct BestOfSequence {
|
pub(crate) struct BestOfSequence {
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
@ -1162,6 +1166,12 @@ pub(crate) struct GenerateResponse {
|
|||||||
pub details: Option<Details>,
|
pub details: Option<Details>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub(crate) struct ChatTokenizeResponse {
|
||||||
|
pub(crate) tokenize_response: TokenizeResponse,
|
||||||
|
pub(crate) templated_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
#[serde(transparent)]
|
#[serde(transparent)]
|
||||||
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
||||||
|
81
router/src/logging.rs
Normal file
81
router/src/logging.rs
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
|
use opentelemetry::sdk::trace;
|
||||||
|
use opentelemetry::sdk::trace::Sampler;
|
||||||
|
use opentelemetry::sdk::Resource;
|
||||||
|
use opentelemetry::{global, KeyValue};
|
||||||
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||||
|
|
||||||
|
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||||
|
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||||
|
/// - otlp_service_name service name to appear in APM
|
||||||
|
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||||
|
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||||
|
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
||||||
|
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
||||||
|
let mut layers = Vec::new();
|
||||||
|
|
||||||
|
// STDOUT/STDERR layer
|
||||||
|
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||||
|
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||||
|
.with_file(true)
|
||||||
|
.with_ansi(ansi)
|
||||||
|
.with_line_number(true);
|
||||||
|
|
||||||
|
let fmt_layer = match json_output {
|
||||||
|
true => fmt_layer.json().flatten_event(true).boxed(),
|
||||||
|
false => fmt_layer.boxed(),
|
||||||
|
};
|
||||||
|
layers.push(fmt_layer);
|
||||||
|
|
||||||
|
// OpenTelemetry tracing layer
|
||||||
|
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||||
|
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||||
|
|
||||||
|
let tracer = opentelemetry_otlp::new_pipeline()
|
||||||
|
.tracing()
|
||||||
|
.with_exporter(
|
||||||
|
opentelemetry_otlp::new_exporter()
|
||||||
|
.tonic()
|
||||||
|
.with_endpoint(otlp_endpoint),
|
||||||
|
)
|
||||||
|
.with_trace_config(
|
||||||
|
trace::config()
|
||||||
|
.with_resource(Resource::new(vec![KeyValue::new(
|
||||||
|
"service.name",
|
||||||
|
otlp_service_name,
|
||||||
|
)]))
|
||||||
|
.with_sampler(Sampler::AlwaysOn),
|
||||||
|
)
|
||||||
|
.install_batch(opentelemetry::runtime::Tokio);
|
||||||
|
|
||||||
|
if let Ok(tracer) = tracer {
|
||||||
|
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||||
|
init_tracing_opentelemetry::init_propagator().unwrap();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter events with LOG_LEVEL
|
||||||
|
let varname = "LOG_LEVEL";
|
||||||
|
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
||||||
|
// Override to avoid simple logs to be spammed with tokio level informations
|
||||||
|
let log_level = match &log_level[..] {
|
||||||
|
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
||||||
|
"info" => "text_generation_launcher=info,text_generation_router=info",
|
||||||
|
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
||||||
|
log_level => log_level,
|
||||||
|
};
|
||||||
|
EnvFilter::builder()
|
||||||
|
.with_default_directive(LevelFilter::INFO.into())
|
||||||
|
.parse_lossy(log_level)
|
||||||
|
} else {
|
||||||
|
EnvFilter::new("info")
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(env_filter)
|
||||||
|
.with(layers)
|
||||||
|
.init();
|
||||||
|
}
|
1085
router/src/server.rs
1085
router/src/server.rs
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,5 @@
|
|||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
use clap::ValueEnum;
|
||||||
use csv::ReaderBuilder;
|
use csv::ReaderBuilder;
|
||||||
use reqwest::header::HeaderMap;
|
use reqwest::header::HeaderMap;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
@ -13,6 +14,13 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi";
|
const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi";
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, Serialize, ValueEnum)]
|
||||||
|
pub enum UsageStatsLevel {
|
||||||
|
On,
|
||||||
|
NoStack,
|
||||||
|
Off,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
pub struct UserAgent {
|
pub struct UserAgent {
|
||||||
pub uid: String,
|
pub uid: String,
|
||||||
@ -71,72 +79,69 @@ impl UsageStatsEvent {
|
|||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
pub struct Args {
|
pub struct Args {
|
||||||
model_config: Option<Config>,
|
model_config: Option<Config>,
|
||||||
tokenizer_config: Option<String>,
|
tokenizer_class: Option<String>,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
max_best_of: usize,
|
max_best_of: usize,
|
||||||
max_stop_sequences: usize,
|
max_stop_sequences: usize,
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
waiting_served_ratio: f32,
|
// waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
// max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: Option<u32>,
|
// max_batch_total_tokens: Option<u32>,
|
||||||
max_waiting_tokens: usize,
|
// max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
// max_batch_size: Option<usize>,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
disable_usage_stats: bool,
|
usage_stats_level: UsageStatsLevel,
|
||||||
disable_crash_reports: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
model_config: Option<Config>,
|
model_config: Option<Config>,
|
||||||
tokenizer_config: Option<String>,
|
tokenizer_class: Option<String>,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
max_best_of: usize,
|
max_best_of: usize,
|
||||||
max_stop_sequences: usize,
|
max_stop_sequences: usize,
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
waiting_served_ratio: f32,
|
// waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
// max_batch_prefill_tokens: u32,
|
||||||
max_batch_total_tokens: Option<u32>,
|
// max_batch_total_tokens: Option<u32>,
|
||||||
max_waiting_tokens: usize,
|
// max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
// max_batch_size: Option<usize>,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
disable_usage_stats: bool,
|
usage_stats_level: UsageStatsLevel,
|
||||||
disable_crash_reports: bool,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
model_config,
|
model_config,
|
||||||
tokenizer_config,
|
tokenizer_class,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
max_best_of,
|
max_best_of,
|
||||||
max_stop_sequences,
|
max_stop_sequences,
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
waiting_served_ratio,
|
// waiting_served_ratio,
|
||||||
max_batch_prefill_tokens,
|
// max_batch_prefill_tokens,
|
||||||
max_batch_total_tokens,
|
// max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
// max_waiting_tokens,
|
||||||
max_batch_size,
|
// max_batch_size,
|
||||||
revision,
|
revision,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
disable_usage_stats,
|
usage_stats_level,
|
||||||
disable_crash_reports,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,13 +5,12 @@ use crate::{
|
|||||||
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
|
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
|
||||||
};
|
};
|
||||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||||
use image::{io::Reader as ImageReader, ImageFormat};
|
use image::{ImageFormat, ImageReader};
|
||||||
use jsonschema::{Draft, JSONSchema};
|
use jsonschema::{Draft, JSONSchema};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use std::iter;
|
use std::iter;
|
||||||
use text_generation_client::{Chunk, Image, InputChunk};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
@ -96,7 +95,7 @@ impl Validation {
|
|||||||
&self,
|
&self,
|
||||||
inputs: String,
|
inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
|
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
if let Some(sender) = &self.sender {
|
if let Some(sender) = &self.sender {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
@ -122,7 +121,7 @@ impl Validation {
|
|||||||
inputs: String,
|
inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
max_new_tokens: Option<u32>,
|
max_new_tokens: Option<u32>,
|
||||||
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
@ -181,11 +180,7 @@ impl Validation {
|
|||||||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((
|
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
|
||||||
vec![Chunk::Text(inputs).into()],
|
|
||||||
input_length,
|
|
||||||
max_new_tokens,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -589,7 +584,7 @@ fn prepare_input(
|
|||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
config: Option<&Config>,
|
config: Option<&Config>,
|
||||||
preprocessor_config: Option<&HubPreprocessorConfig>,
|
preprocessor_config: Option<&HubPreprocessorConfig>,
|
||||||
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
|
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
|
||||||
use Config::*;
|
use Config::*;
|
||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
let (tokenizer_query, input_chunks) = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
@ -601,16 +596,16 @@ fn prepare_input(
|
|||||||
let chunk_start = chunk.start();
|
let chunk_start = chunk.start();
|
||||||
let chunk_end = chunk.end();
|
let chunk_end = chunk.end();
|
||||||
if chunk_start != start {
|
if chunk_start != start {
|
||||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
||||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
if start != inputs.len() {
|
if start != inputs.len() {
|
||||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..]);
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -618,7 +613,7 @@ fn prepare_input(
|
|||||||
|
|
||||||
(tokenizer_query, input_chunks)
|
(tokenizer_query, input_chunks)
|
||||||
}
|
}
|
||||||
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
|
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get the number of tokens in the input
|
// Get the number of tokens in the input
|
||||||
@ -631,18 +626,51 @@ fn prepare_input(
|
|||||||
|
|
||||||
type TokenizerRequest = (
|
type TokenizerRequest = (
|
||||||
(String, Option<usize>),
|
(String, Option<usize>),
|
||||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
|
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
|
||||||
Span,
|
Span,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
|
pub struct Image {
|
||||||
|
pub data: Vec<u8>,
|
||||||
|
pub mimetype: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
|
pub enum Chunk {
|
||||||
|
Text(String),
|
||||||
|
Image(Image),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert input chunks to a stringly-typed input for backwards
|
||||||
|
/// compat for backends that haven't implemented chunked inputs.
|
||||||
|
pub trait ChunksToString {
|
||||||
|
/// Convert chunks to string.
|
||||||
|
fn chunks_to_string(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChunksToString for Vec<Chunk> {
|
||||||
|
fn chunks_to_string(&self) -> String {
|
||||||
|
let mut output = String::new();
|
||||||
|
self.iter().for_each(|c| match &c {
|
||||||
|
Chunk::Text(text) => output.push_str(text),
|
||||||
|
Chunk::Image(Image { data, mimetype }) => {
|
||||||
|
let encoded = STANDARD.encode(data);
|
||||||
|
output.push_str(&format!("", mimetype, encoded))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) enum ValidGrammar {
|
pub enum ValidGrammar {
|
||||||
Json(String),
|
Json(String),
|
||||||
Regex(String),
|
Regex(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct ValidParameters {
|
pub struct ValidParameters {
|
||||||
/// / exponential scaling output probability distribution
|
/// / exponential scaling output probability distribution
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
/// / restricting to the k highest probability elements
|
/// / restricting to the k highest probability elements
|
||||||
@ -666,7 +694,7 @@ pub(crate) struct ValidParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct ValidStoppingParameters {
|
pub struct ValidStoppingParameters {
|
||||||
/// / Maximum number of generated tokens
|
/// / Maximum number of generated tokens
|
||||||
pub max_new_tokens: u32,
|
pub max_new_tokens: u32,
|
||||||
/// / Optional stopping sequences
|
/// / Optional stopping sequences
|
||||||
@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct ValidGenerateRequest {
|
pub struct ValidGenerateRequest {
|
||||||
pub inputs: Vec<InputChunk>,
|
pub inputs: Vec<Chunk>,
|
||||||
pub input_length: u32,
|
pub input_length: u32,
|
||||||
pub truncate: u32,
|
pub truncate: u32,
|
||||||
pub decoder_input_details: bool,
|
pub decoder_input_details: bool,
|
||||||
@ -750,6 +778,8 @@ pub enum ValidationError {
|
|||||||
InvalidImageContent(String),
|
InvalidImageContent(String),
|
||||||
#[error("Could not fetch image: {0}")]
|
#[error("Could not fetch image: {0}")]
|
||||||
FailedFetchImage(#[from] reqwest::Error),
|
FailedFetchImage(#[from] reqwest::Error),
|
||||||
|
#[error("{0} modality is not supported")]
|
||||||
|
UnsupportedModality(&'static str),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -34,7 +34,6 @@ def reshape_and_cache(
|
|||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
out: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
@ -85,7 +84,7 @@ def paged_attention(
|
|||||||
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
||||||
if softcap is None:
|
if softcap is None:
|
||||||
softcap = 0.0
|
softcap = 0.0
|
||||||
out2 = flash_attn_2_cuda.varlen_fwd(
|
out = flash_attn_2_cuda.varlen_fwd(
|
||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
@ -108,13 +107,15 @@ def paged_attention(
|
|||||||
False, # return softmax
|
False, # return softmax
|
||||||
None, # generator
|
None, # generator
|
||||||
)
|
)
|
||||||
return out2[0]
|
return out[0]
|
||||||
else:
|
else:
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
|
||||||
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
use_v1 = max_s <= 8192 and (
|
use_v1 = max_s <= 8192 and (
|
||||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||||
)
|
)
|
||||||
@ -171,6 +172,10 @@ def paged_attention(
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
is_ampere_or_newer = major >= 8 and minor >= 0
|
||||||
|
if not is_ampere_or_newer:
|
||||||
|
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
|
||||||
|
|
||||||
import flash_attn_2_cuda
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
V2 = True
|
V2 = True
|
||||||
@ -200,13 +205,13 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = V2
|
SUPPORTS_WINDOWING = V2
|
||||||
|
|
||||||
if V2:
|
if V2:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
out,
|
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
@ -214,6 +219,7 @@ if V2:
|
|||||||
causal=True,
|
causal=True,
|
||||||
softcap=0.0,
|
softcap=0.0,
|
||||||
):
|
):
|
||||||
|
out = torch.empty_like(q)
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
@ -238,7 +244,7 @@ if V2:
|
|||||||
softcap,
|
softcap,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
)
|
)[0]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@ -246,7 +252,6 @@ else:
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
out,
|
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
@ -286,7 +291,8 @@ else:
|
|||||||
.reshape(original_shape[0], -1, original_shape[2])
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
)
|
)
|
||||||
|
|
||||||
return flash_attn_cuda.fwd(
|
out = torch.empty_like(q)
|
||||||
|
flash_attn_cuda.fwd(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
@ -303,3 +309,4 @@ else:
|
|||||||
0,
|
0,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
return out
|
||||||
|
@ -11,7 +11,6 @@ def attention(
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
out,
|
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
@ -19,6 +18,8 @@ def attention(
|
|||||||
causal=True,
|
causal=True,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
return ipex.llm.functional.varlen_attention(
|
return ipex.llm.functional.varlen_attention(
|
||||||
q,
|
q,
|
||||||
@ -51,7 +52,6 @@ def reshape_and_cache(
|
|||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
out: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
@ -62,6 +62,7 @@ def paged_attention(
|
|||||||
max_s: int,
|
max_s: int,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
out = torch.empty_like(query)
|
||||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
|
@ -39,7 +39,6 @@ def reshape_and_cache(
|
|||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
out: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
@ -72,6 +71,8 @@ def paged_attention(
|
|||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
input_lengths = input_lengths.input_lengths
|
input_lengths = input_lengths.input_lengths
|
||||||
|
|
||||||
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
@ -174,7 +175,6 @@ if ENGINE == "ck":
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
out,
|
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
@ -184,6 +184,8 @@ if ENGINE == "ck":
|
|||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
q,
|
||||||
@ -209,13 +211,14 @@ elif ENGINE == "triton":
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
out,
|
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
causal=True,
|
causal=True,
|
||||||
):
|
):
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
output, _ = triton_attention(
|
output, _ = triton_attention(
|
||||||
q,
|
q,
|
||||||
|
@ -124,50 +124,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
self.sym = sym
|
self.sym = sym
|
||||||
|
|
||||||
def get_weights(self, weights: Weights, prefix: str):
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
|
||||||
try:
|
|
||||||
qweight = weights.get_tensor(f"{prefix}.qweight")
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
if self.bits != 4:
|
if self.bits != 4:
|
||||||
@ -248,11 +205,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
prefix: str,
|
prefix: str,
|
||||||
block_sizes: Union[int, List[int]],
|
block_sizes: Union[int, List[int]],
|
||||||
):
|
):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = weights.get_packed_sharded(
|
qweight = weights.get_packed_sharded(
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
@ -267,36 +219,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
scales = scales.to(dtype=weights.dtype)
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = weights.get_packed_sharded(
|
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
qzeros = weights.get_packed_sharded(
|
qzeros = weights.get_packed_sharded(
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
@ -334,11 +256,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat(
|
qweight = torch.cat(
|
||||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
@ -353,41 +270,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
qzeros = torch.cat(
|
|
||||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
|
||||||
for w2 in w[1:]:
|
|
||||||
torch.testing.assert_close(w2, w[0])
|
|
||||||
g_idx = w[0]
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
qzeros = torch.cat(
|
qzeros = torch.cat(
|
||||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
@ -441,59 +323,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_weights_row(self, weights: Weights, prefix: str):
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
from text_generation_server.layers.marlin import (
|
|
||||||
can_use_gptq_marlin,
|
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
if can_use_gptq_marlin(
|
|
||||||
bits=self.bits,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
quantize=self.quantize,
|
|
||||||
sym=self.sym,
|
|
||||||
):
|
|
||||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
|
||||||
try:
|
|
||||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.sym:
|
|
||||||
if self.desc_act or self.groupsize == -1:
|
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
|
||||||
else:
|
|
||||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
|
||||||
else:
|
|
||||||
qzeros = None
|
|
||||||
|
|
||||||
if self.quant_method == "awq":
|
|
||||||
g_idx = None
|
|
||||||
else:
|
|
||||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
|
||||||
|
|
||||||
if self.desc_act or self.groupsize == -1:
|
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
|
||||||
else:
|
|
||||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
|
||||||
|
|
||||||
sharded_in_features = weights.process_group.size() > 1
|
|
||||||
|
|
||||||
return repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
qzeros=qzeros,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=self.bits,
|
|
||||||
desc_act=self.desc_act,
|
|
||||||
groupsize=self.groupsize,
|
|
||||||
quant_method=self.quant_method,
|
|
||||||
sym=self.sym,
|
|
||||||
sharded_infeatures=sharded_in_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
if self.bits != 4:
|
if self.bits != 4:
|
||||||
|
@ -17,7 +17,7 @@ from loguru import logger
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from text_generation_server.layers.gptq.utils import torch_snr_error
|
from text_generation_server.layers.gptq.utils import torch_snr_error
|
||||||
|
|
||||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
|
||||||
|
|
||||||
DEV = torch.device("cuda:0")
|
DEV = torch.device("cuda:0")
|
||||||
|
|
||||||
@ -897,7 +897,7 @@ def quantize(
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
||||||
weights_loader=DefaultWeightsLoader(),
|
weights_loader=DefaultWeightsLoader(UnquantizedWeight),
|
||||||
)
|
)
|
||||||
hooks = []
|
hooks = []
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
@ -960,9 +960,6 @@ def quantize(
|
|||||||
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||||
state_dict["gptq_bits"] = torch.LongTensor([bits])
|
|
||||||
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
|
|
||||||
state_dict["gptq_sym"] = torch.BoolTensor([sym])
|
|
||||||
|
|
||||||
max_shard_size = "10GB"
|
max_shard_size = "10GB"
|
||||||
shards, index = shard_checkpoint(
|
shards, index = shard_checkpoint(
|
||||||
@ -994,6 +991,15 @@ def quantize(
|
|||||||
f"index located at {save_index_file}."
|
f"index located at {save_index_file}."
|
||||||
)
|
)
|
||||||
config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||||
|
config.quantization_config = {
|
||||||
|
"bits": bits,
|
||||||
|
"group_size": groupsize,
|
||||||
|
"damp_percent": percdamp,
|
||||||
|
"desc_act": act_order,
|
||||||
|
"static_groups": False,
|
||||||
|
"sym": sym,
|
||||||
|
"quant_method": "gptq",
|
||||||
|
}
|
||||||
config.save_pretrained(output_dir)
|
config.save_pretrained(output_dir)
|
||||||
logger.info("Saved config")
|
logger.info("Saved config")
|
||||||
logger.info("Saving tokenizer")
|
logger.info("Saving tokenizer")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
||||||
from text_generation_server.layers.marlin.gptq import (
|
from text_generation_server.layers.marlin.gptq import (
|
||||||
GPTQMarlinLinear,
|
GPTQMarlinWeightsLoader,
|
||||||
GPTQMarlinWeight,
|
|
||||||
can_use_gptq_marlin,
|
can_use_gptq_marlin,
|
||||||
repack_gptq_for_marlin,
|
repack_gptq_for_marlin,
|
||||||
)
|
)
|
||||||
@ -9,8 +8,7 @@ from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GPTQMarlinFP8Linear",
|
"GPTQMarlinFP8Linear",
|
||||||
"GPTQMarlinLinear",
|
"GPTQMarlinWeightsLoader",
|
||||||
"GPTQMarlinWeight",
|
|
||||||
"MarlinWeightsLoader",
|
"MarlinWeightsLoader",
|
||||||
"can_use_gptq_marlin",
|
"can_use_gptq_marlin",
|
||||||
"repack_gptq_for_marlin",
|
"repack_gptq_for_marlin",
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
@ -13,7 +13,7 @@ from text_generation_server.layers.marlin.util import (
|
|||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import Weight
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import marlin_kernels
|
import marlin_kernels
|
||||||
@ -48,6 +48,204 @@ def can_use_gptq_marlin(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||||
|
"""
|
||||||
|
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
bits: int,
|
||||||
|
desc_act: bool,
|
||||||
|
groupsize: int,
|
||||||
|
quant_method: str,
|
||||||
|
quantize: str,
|
||||||
|
sym: bool,
|
||||||
|
):
|
||||||
|
self.bits = bits
|
||||||
|
self.desc_act = desc_act
|
||||||
|
self.groupsize = groupsize
|
||||||
|
self.quant_method = quant_method
|
||||||
|
self.quantize = quantize
|
||||||
|
self.sym = sym
|
||||||
|
|
||||||
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = weights.get_tensor(f"{prefix}.qweight")
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||||
|
)
|
||||||
|
scales = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scales = scales.to(dtype=weights.dtype)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
qzeros = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
try:
|
||||||
|
qweight = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
scales = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.sym:
|
||||||
|
if self.desc_act or self.groupsize == -1:
|
||||||
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
|
else:
|
||||||
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
|
else:
|
||||||
|
qzeros = None
|
||||||
|
|
||||||
|
if self.quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
|
||||||
|
if self.desc_act or self.groupsize == -1:
|
||||||
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
else:
|
||||||
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
|
||||||
|
sharded_in_features = weights.process_group.size() > 1
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
qzeros=qzeros,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=self.bits,
|
||||||
|
desc_act=self.desc_act,
|
||||||
|
groupsize=self.groupsize,
|
||||||
|
quant_method=self.quant_method,
|
||||||
|
sym=self.sym,
|
||||||
|
sharded_infeatures=sharded_in_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_gptq_params(self, weights: Weights):
|
||||||
|
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||||
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
|
self.desc_act = False
|
||||||
|
# `server quantize` used asymmetric quantization unconditionally
|
||||||
|
# before the `gptq_sym` setting tensor was added.
|
||||||
|
self.sym = (
|
||||||
|
weights.get_tensor("gptq_sym").item()
|
||||||
|
if weights._has_tensor("gptq_sym")
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
self.quant_method = "gptq"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQMarlinWeight(Weight):
|
class GPTQMarlinWeight(Weight):
|
||||||
"""
|
"""
|
||||||
|
@ -484,6 +484,9 @@ def get_model(
|
|||||||
)
|
)
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
|
|
||||||
|
if max_input_tokens is not None and max_input_tokens <= sliding_window:
|
||||||
|
sliding_window = -1
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(sliding_window is not None and sliding_window != -1)
|
(sliding_window is not None and sliding_window != -1)
|
||||||
and not SUPPORTS_WINDOWING
|
and not SUPPORTS_WINDOWING
|
||||||
|
@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# Output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
|
|||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
qkv[:, 1],
|
qkv[:, 1],
|
||||||
qkv[:, 2],
|
qkv[:, 2],
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
# Reshape key and value and cache
|
# Reshape key and value and cache
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
|
@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
|
|
||||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
slots,
|
slots,
|
||||||
)
|
)
|
||||||
|
|
||||||
# output
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attn_output = attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=2, index=0),
|
torch.select(kv, dim=2, index=0),
|
||||||
torch.select(kv, dim=2, index=1),
|
torch.select(kv, dim=2, index=1),
|
||||||
attn_output,
|
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
attn_output,
|
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
@ -392,8 +382,13 @@ class FlashRWLayer(nn.Module):
|
|||||||
|
|
||||||
prefix = f"{prefix}.h.{layer_id}"
|
prefix = f"{prefix}.h.{layer_id}"
|
||||||
|
|
||||||
|
# NOTE: Falcon 180B uses the ln_attn prefix
|
||||||
|
ln_prefix = "input_layernorm"
|
||||||
|
if config.num_hidden_layers == 80:
|
||||||
|
ln_prefix = "ln_attn"
|
||||||
|
|
||||||
self.input_layernorm = FastLayerNorm.load(
|
self.input_layernorm = FastLayerNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm",
|
prefix=f"{prefix}.{ln_prefix}",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_epsilon,
|
eps=config.layer_norm_epsilon,
|
||||||
)
|
)
|
||||||
@ -483,7 +478,13 @@ class FlashRWLayer(nn.Module):
|
|||||||
class FlashRWLayerNorm(nn.Module):
|
class FlashRWLayerNorm(nn.Module):
|
||||||
def __init__(self, config, prefix: str, weights):
|
def __init__(self, config, prefix: str, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_ln = config.num_ln_in_parallel_attn
|
# Falcon2 includes the number of layer norms in the config
|
||||||
|
# in the case no number of layer norms is provided, we default to 1
|
||||||
|
self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
|
||||||
|
|
||||||
|
# Falcon 180B uses the ln_attn prefix and has 2 layer norms
|
||||||
|
if config.num_hidden_layers == 80:
|
||||||
|
self.num_ln = 2
|
||||||
|
|
||||||
if self.num_ln == 1:
|
if self.num_ln == 1:
|
||||||
self.input_ln = FastLayerNorm.load(
|
self.input_ln = FastLayerNorm.load(
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user