Merge branch 'main' into auto_length

Author: Nicolas Patry (committed via GitHub)
Date: 2024-10-25 10:20:00 +02:00
Commit: c3fb2ecdc0
82 changed files with 2616 additions and 2395 deletions

Cargo.lock (generated)

@@ -2706,9 +2706,9 @@ dependencies = [
 [[package]]
 name = "opentelemetry"
-version = "0.23.0"
+version = "0.24.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76"
+checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
 dependencies = [
  "futures-core",
  "futures-sink",
@@ -2819,19 +2819,17 @@ dependencies = [
 [[package]]
 name = "opentelemetry_sdk"
-version = "0.23.0"
+version = "0.24.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd"
+checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
 dependencies = [
  "async-trait",
  "futures-channel",
  "futures-executor",
  "futures-util",
  "glob",
- "lazy_static",
  "once_cell",
- "opentelemetry 0.23.0",
+ "opentelemetry 0.24.0",
- "ordered-float 4.3.0",
  "percent-encoding",
  "rand",
  "thiserror",
@@ -4185,16 +4183,17 @@ dependencies = [
  "cmake",
  "cxx",
  "cxx-build",
+ "hashbrown 0.14.5",
+ "hf-hub",
  "log",
- "parking_lot",
  "pkg-config",
  "text-generation-router",
  "thiserror",
- "tokenizers 0.19.1",
+ "tokenizers",
  "tokio",
  "tokio-stream",
  "tracing",
- "tracing-opentelemetry 0.24.0",
+ "tracing-opentelemetry 0.25.0",
  "tracing-subscriber",
 ]
@@ -4212,7 +4211,7 @@ dependencies = [
  "tabled",
  "text-generation-client",
  "thiserror",
- "tokenizers 0.20.0",
+ "tokenizers",
  "tokio",
  "tracing",
  "tracing-subscriber",
@@ -4292,7 +4291,7 @@ dependencies = [
  "serde_json",
  "sysinfo",
  "thiserror",
- "tokenizers 0.20.0",
+ "tokenizers",
  "tokio",
  "tokio-stream",
  "tower-http",
@@ -4341,7 +4340,7 @@ dependencies = [
  "slotmap",
  "text-generation-router",
  "thiserror",
- "tokenizers 0.20.0",
+ "tokenizers",
  "tokio",
  "tokio-stream",
  "tonic 0.10.2",
@@ -4392,7 +4391,7 @@ dependencies = [
  "slotmap",
  "text-generation-router",
  "thiserror",
- "tokenizers 0.20.0",
+ "tokenizers",
  "tokio",
  "tokio-stream",
  "tonic 0.10.2",
@@ -4514,39 +4513,6 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
-[[package]]
-name = "tokenizers"
-version = "0.19.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
-dependencies = [
- "aho-corasick",
- "derive_builder",
- "esaxx-rs",
- "getrandom",
- "hf-hub",
- "indicatif",
- "itertools 0.12.1",
- "lazy_static",
- "log",
- "macro_rules_attribute",
- "monostate",
- "onig",
- "paste",
- "rand",
- "rayon",
- "rayon-cond",
- "regex",
- "regex-syntax 0.8.5",
- "serde",
- "serde_json",
- "spm_precompiled",
- "thiserror",
- "unicode-normalization-alignments",
- "unicode-segmentation",
- "unicode_categories",
-]
 [[package]]
 name = "tokenizers"
 version = "0.20.0"
@@ -4933,14 +4899,14 @@ dependencies = [
 [[package]]
 name = "tracing-opentelemetry"
-version = "0.24.0"
+version = "0.25.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4"
+checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
 dependencies = [
  "js-sys",
  "once_cell",
- "opentelemetry 0.23.0",
+ "opentelemetry 0.24.0",
- "opentelemetry_sdk 0.23.0",
+ "opentelemetry_sdk 0.24.1",
  "smallvec",
  "tracing",
  "tracing-core",


@@ -1,23 +0,0 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
# Include CUDA related libraries and tools to the Rust based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3


@@ -10,7 +10,7 @@ COPY . .
 RUN cargo chef prepare --recipe-path recipe.json
 # CUDA dependent dependencies resolver stage
-FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
+FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
 RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -26,6 +26,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     ninja-build \
     pkg-config \
     python3 \
+    python3-dev \
     python3-setuptools \
     tar \
     wget
@@ -82,10 +83,15 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
     cd backends/trtllm && \
     CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
-FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
+FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
+RUN apt update && apt install -y python3 && \
+    rm -rf /var/lib/{apt,dpkg,cache,log}/
 WORKDIR /usr/local/tgi/bin
 ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+ENV TOKENIZERS_PARALLELISM=false
+ENV OMPI_MCA_plm_rsh_agent=""
 COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
 COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt


@@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
 You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
 ```bash
-curl localhost:3000/v1/chat/completions \
+curl localhost:8080/v1/chat/completions \
     -X POST \
     -d '{
   "model": "tgi",

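Both documented endpoints now share port 8080, so a client only needs one base URL. For readers who prefer calling the Messages API from Rust rather than curl, here is a minimal sketch of the same request; reqwest, serde_json and tokio (with the macros feature) are assumed extra dependencies, and the prompt text is illustrative rather than part of this repository:

```rust
// Minimal sketch: same call as the documented curl example, against a TGI
// instance assumed to be listening on localhost:8080.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "stream": false
    });
    let resp = reqwest::Client::new()
        .post("http://localhost:8080/v1/chat/completions")
        .json(&body)
        .send()
        .await?
        .text()
        .await?;
    println!("{resp}");
    Ok(())
}
```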

@@ -1,5 +1,17 @@
 cmake_minimum_required(VERSION 3.20)
+if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
+    find_program(CCACHE_EXECUTABLE "ccache")
+    if (CCACHE_EXECUTABLE)
+        message(STATUS "Using ccache")
+        set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
+    endif ()
+endif ()
+if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
+    cmake_policy(SET CMP0135 NEW)
+endif ()
 project(tgi-trtllm-backend VERSION 1.0.0)
 set(CMAKE_CXX_STANDARD 20)
@@ -14,7 +26,7 @@ set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include"
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
 # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
-find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
+find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
 #### External dependencies ####
 include(cmake/fmt.cmake)


@@ -10,16 +10,17 @@ async-trait = "0.1"
 async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
+hashbrown = "0.14"
+hf-hub = { workspace = true }
 log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
-tokenizers = { version = "0.19", features = ["hf-hub"] }
+tokenizers = { workspace = true }
-tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.15"
-thiserror = "1.0.62"
+thiserror = "1.0.63"
 tracing = "0.1"
-tracing-opentelemetry = "0.24"
+tracing-opentelemetry = "0.25"
 tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
-parking_lot = "0.12"
 [build-dependencies]
 cmake = "0.1"


@@ -6,7 +6,7 @@ use std::path::{absolute, PathBuf};
 const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
 const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
-const CUDA_REQUIRED_VERSION: &str = "12.5";
+const CUDA_REQUIRED_VERSION: &str = "12.6";
 const MPI_REQUIRED_VERSION: &str = "4.1";
 const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
 const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
@@ -36,7 +36,7 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     // Build the backend implementation through CMake
     let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
     let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
-    let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
+    let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
     let mut install_path = PathBuf::from(install_path);
     if !install_path.is_absolute() {
@@ -81,7 +81,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     (PathBuf::from(install_path), deps_folder)
 }
-fn build_ffi_layer(deps_folder: &PathBuf) {
+fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
+    let ndebug = match is_debug {
+        true => "1",
+        false => "0",
+    };
     CFG.include_prefix = "backends/trtllm";
     cxx_build::bridge("src/lib.rs")
         .static_flag(true)
@@ -93,9 +98,14 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
         .include("/usr/local/tensorrt/include")
         .file("src/ffi.cpp")
         .std("c++20")
+        .define("NDEBUG", ndebug)
         .compile("tgi_trtllm_backend");
     println!("cargo:rerun-if-changed=CMakeLists.txt");
+    println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
+    println!("cargo:rerun-if-changed=cmake/json.cmake");
+    println!("cargo:rerun-if-changed=cmake/fmt.cmake");
+    println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
     println!("cargo:rerun-if-changed=include/backend.h");
     println!("cargo:rerun-if-changed=lib/backend.cpp");
     println!("cargo:rerun-if-changed=include/ffi.h");
@@ -115,7 +125,7 @@ fn main() {
     let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
     // Build the FFI layer calling the backend above
-    build_ffi_layer(&deps_folder);
+    build_ffi_layer(&deps_folder, is_debug);
     // Emit linkage search path
     probe!("ompi", MPI_REQUIRED_VERSION);
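The default CUDA architecture list now spans Turing (75-real) through Hopper (90-real) instead of Hopper only, and build_ffi_layer now receives is_debug so the cxx compilation can set NDEBUG in line with the CMake build type. As a rough illustration only (not this crate's actual build_backend body; the define name and helper below are assumptions), an env-driven architecture list is typically forwarded to CMake from a Rust build script like this:

```rust
// Illustrative sketch, not the real build_backend(): forward an env-driven
// architecture list to CMake. CMAKE_CUDA_ARCHITECTURES is an assumed define name.
use std::path::PathBuf;

const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");

fn configure_trtllm_cmake(source_dir: &str) -> PathBuf {
    // Same default as the diff: Turing (75), Ampere (80/86), Ada (89) and Hopper (90).
    let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
    cmake::Config::new(source_dir)
        .define("CMAKE_CUDA_ARCHITECTURES", cuda_arch_list)
        .build()
}
```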


@@ -1,6 +1,6 @@
 FetchContent_Declare(
         fmt
-        GIT_REPOSITORY https://github.com/fmtlib/fmt
-        GIT_TAG 11.0.1
+        DOWNLOAD_EXTRACT_TIMESTAMP
+        URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
 )
 FetchContent_MakeAvailable(fmt)


@@ -1,5 +1,6 @@
 fetchcontent_declare(
         json
+        DOWNLOAD_EXTRACT_TIMESTAMP
         URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
 )
 fetchcontent_makeavailable(json)


@@ -11,7 +11,7 @@ endif ()
 fetchcontent_declare(
         spdlog
-        GIT_REPOSITORY https://github.com/gabime/spdlog.git
-        GIT_TAG v1.14.1
+        DOWNLOAD_EXTRACT_TIMESTAMP
+        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
 )
 fetchcontent_makeavailable(spdlog)


@@ -23,8 +23,9 @@ endif ()
 fetchcontent_declare(
         trtllm
         GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
-        GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
+        GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
         GIT_SHALLOW FALSE
+        DOWNLOAD_EXTRACT_TIMESTAMP
 )
 fetchcontent_makeavailable(trtllm)


@@ -23,6 +23,12 @@ namespace huggingface::tgi::backends {
     using RequestId = tle::IdType;
     using TokenId = tle::TokenIdType;
+    const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
+    constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
+            "Submitting inference [{}] to the executor ({:d} already in-flight)");
+    constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
+            "Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
     /**
      * Initialize all the components required by TRTLLM.
      * It is required to call this function before attempting to load any engine
@@ -54,7 +60,7 @@
             float_t repetition_penalty,
             float_t frequency_penalty,
             uint64_t seed
-    );
+    ) noexcept;
     /**
      *
@@ -64,18 +70,15 @@
         const json config;
         tle::Executor executor;
+        /** Frequently accessed variables cached here **/
+        uint32_t maxNumTokens;
     public:
         explicit TensorRtLlmBackend(
                 const std::filesystem::path &engineFolder,
                 const std::filesystem::path &executorWorker
         );
-        /**
-         * Indicate if the backend is ready to accept incoming request
-         * @return true if ready, false otherwise
-         */
-        [[nodiscard]] bool IsReady() const;
         /**
          * Query the executor for the number of token available for pulling
         * @return
@@ -95,25 +98,16 @@
         */
        [[nodiscard]] RequestId Submit(
                const std::vector<TokenId> &tokens,
-               int32_t topK,
-               float_t topP,
-               float_t temperature,
-               float_t repetition_penalty,
-               float_t frequency_penalty,
-               uint64_t seed
+               const uint32_t maxNewTokens,
+               const int32_t topK,
+               const float_t topP,
+               const float_t temperature,
+               const float_t repetition_penalty,
+               const float_t frequency_penalty,
+               const uint64_t seed
        );
-        /**
-         *
-         * @param requestId The request id to poll the generation results
-         * @return
-         */
-        std::vector<tle::Response> Poll(RequestId requestId);
-        /**
-         * Stop the underlying executor
-         */
-        void Shutdown();
+        [[nodiscard]] std::vector<tle::Response> PullNewTokens();
     };
 }


@@ -5,20 +5,31 @@
 #ifndef TGI_TRTLLM_BACKEND_FFI_H
 #define TGI_TRTLLM_BACKEND_FFI_H
+#include <cmath>
 #include <cstddef>
+#include <memory>
 #include "backend.h"
 namespace huggingface::tgi::backends {
     class TensorRtLlmBackendImpl;
 }
+// Template to support returning error from TllmException back to Rust in a Result<>
+#include <tensorrt_llm/common/tllmException.h>
+namespace rust::behavior {
+    template<typename Try, typename Fail>
+    static void trycatch(Try &&func, Fail &&fail) noexcept try {
+        func();
+    } catch (tensorrt_llm::common::TllmException &e) {
+        fail(e.what());
+    }
+}
 #include "backends/trtllm/src/lib.rs.h"
 namespace huggingface::tgi::backends {
-    // struct GenerationContext;
     class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
     public:
         /***
@@ -28,15 +39,10 @@
         */
         TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
-        /***
-         *
-         * @return
-         */
-        bool IsReady() const;
         /***
          *
          * @param tokens
+         * @param maxNewTokens
          * @param topK
         * @param topP
         * @param temperature
@@ -47,21 +53,15 @@
         */
        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
        uint64_t
-        Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
+        Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
+               int32_t topK, float_t topP, float_t temperature,
               float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
        /***
         *
-        * @param requestId
-        * @param ctx
-        * @param callback
         * @return
         */
-        size_t StreamTokens(
-                const RequestId requestId,
-                huggingface::tgi::backends::GenerationContext *ctx,
-                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-                              huggingface::tgi::backends::GenerationStep)> callback);
+        std::unique_ptr<std::vector<GenerationStep>> PullTokens();
    };
 /***
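The new rust::behavior::trycatch specialization above is what turns a thrown tensorrt_llm::common::TllmException into an Err value across the cxx bridge instead of an abort; the bridge declarations in src/lib.rs switch to Result<...> accordingly. A minimal sketch of the Rust-side effect, assuming it lives inside this crate next to the bridge (the try_create helper name is illustrative):

```rust
// Sketch (inside this crate): a C++ TllmException thrown during backend creation
// surfaces as cxx::Exception, whose what() carries the message captured by fail().
use crate::ffi::create_tensorrt_llm_backend;

fn try_create(engine_folder: &str, executor_worker: &str) -> Result<(), String> {
    match create_tensorrt_llm_backend(engine_folder, executor_worker) {
        Ok(_backend) => Ok(()),
        Err(e) => Err(format!("backend creation failed: {}", e.what())),
    }
}
```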


@@ -14,7 +14,7 @@
 namespace huggingface::hardware::cuda {
 #define AMPERE_SM_MAJOR 8
-#define HOPPER_SM_MAJOR 8
+#define HOPPER_SM_MAJOR 9
     /**
      * Store information about the version of the CUDA Compute Capabilities detected on the device


@@ -1,3 +1,4 @@
+#include <cstdlib>
 #include <fstream>
 #include <fmt/ranges.h>
@@ -8,10 +9,23 @@
 #include "hardware.h"
 void huggingface::tgi::backends::InitializeBackend() {
+    if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
+        std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
+        std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
+            return std::tolower(c);
+        });
+        if (log_level == "debug")
+            spdlog::set_level(spdlog::level::debug);
+        else
+            spdlog::set_level(spdlog::level::info);
+    }
     SPDLOG_INFO("Initializing Backend...");
     nvmlInit_v2();
     initTrtLlmPlugins();
+    SPDLOG_INFO("Backend Executor Version: {}", tle::version());
     const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
     if (numGpus.has_value()) {
         SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
@@ -22,7 +36,7 @@ void huggingface::tgi::backends::InitializeBackend() {
 [[nodiscard]]
 tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
-    tle::ExecutorConfig execConfig(1);
+    tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
     // Retrieve the compute capabilities to enable some options at runtime
     const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
@@ -55,12 +69,13 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
 }
 tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
-        uint32_t topK,
-        float_t topP,
-        float_t temperature,
-        float_t repetition_penalty,
-        float_t frequency_penalty,
-        uint64_t seed) {
+        const uint32_t topK,
+        const float_t topP,
+        const float_t temperature,
+        const float_t repetition_penalty,
+        const float_t frequency_penalty,
+        const uint64_t seed) noexcept {
     return tle::SamplingConfig(
             1, // TGI only use a single beam
             topK,
@@ -83,26 +98,29 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         const std::filesystem::path &executorWorker
 ) :
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
-        executor(
-                enginesFolder,
-                tensorrt_llm::executor::ModelType::kDECODER_ONLY,
-                GetExecutorConfig(config, executorWorker.string()
-                )) {
+        executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
+                 GetExecutorConfig(config, executorWorker.string())) {
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
-}
-bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
-    return executor.canEnqueueRequests();
+    // Cache variables
+    maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
 }
 [[nodiscard("Returned number of requests needs to be consumed")]]
 size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
-    return executor.getNumResponsesReady();
+    const auto numResponses = executor.getNumResponsesReady();
+#ifndef NDEBUG
+    if(numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
+#endif
+    return numResponses;
 }
 [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector<tle::TokenIdType> &tokens,
+        const uint32_t maxNewTokens,
         const int32_t topK,
         const float_t topP,
         const float_t temperature,
@@ -110,37 +128,23 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const float_t frequency_penalty,
         const uint64_t seed
 ) {
-#ifdef NDEBUG
-    SPDLOG_DEBUG(
-            FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
-            tokens.size(),
-            executor.getLatestIterationStats().back().numActiveRequests
-    );
-#else
-    SPDLOG_DEBUG(
-            FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
-            fmt::join(tokens, ", "),
-            executor.getLatestIterationStats().front().numActiveRequests
-    );
-#endif
-    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
-    const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
-    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
-    const auto output = tle::OutputConfig(true, false, false, true, false);
-    return executor.enqueueRequest(
-            tle::Request{tokens, maxNewTokens, true, sampling, output});
+    const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
+#ifndef NDEBUG
+    {
+        const auto &iterations = executor.getLatestIterationStats();
+        const auto &lastIteration = iterations.front();
+        SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
+        SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+        SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
+    }
+#endif
+    const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+    const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
+    return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
 }
-[[nodiscard("Generated tokens result must be used")]]
-std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
-    SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
-    return executor.awaitResponses(requestId);
-}
-void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
-    SPDLOG_INFO("Shutting down executor");
-    executor.shutdown();
+std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
+    return executor.awaitResponses();
 }


@@ -2,12 +2,13 @@
 set -ex
-TRT_VER="10.2.0.19"
-CUDA_VER="12.5"
-CUDNN_VER="9.2.1.18-1"
-NCCL_VER="2.22.3-1+cuda12.5"
-CUBLAS_VER="12.5.3.2-1"
-NVRTC_VER="12.5.82-1"
+TRT_VER_BASE="10.4.0"
+TRT_VER_FULL="${TRT_VER_BASE}.26"
+CUDA_VER="12.6"
+CUDNN_VER="9.5.0.50-1"
+NCCL_VER="2.22.3-1+cuda12.6"
+CUBLAS_VER="12.6.3.3-1"
+NVRTC_VER="12.6.77-1"
 for i in "$@"; do
     case $i in
@@ -32,8 +33,9 @@ install_ubuntu_requirements() {
     ARCH=$(uname -m)
     if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
     if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
-    curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
-    dpkg -i cuda-keyring_1.0-1_all.deb
+    curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
+    dpkg -i cuda-keyring_1.1-1_all.deb
+    rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
     apt-get update
     if [[ $(apt list --installed | grep libcudnn9) ]]; then
@@ -71,7 +73,7 @@ install_centos_requirements() {
 install_tensorrt() {
     #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
     #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
-    TRT_CUDA_VERSION="12.5"
+    TRT_CUDA_VERSION="12.6"
     if [ -z "$RELEASE_URL_TRT" ];then
         ARCH=${TRT_TARGETARCH}
@@ -79,12 +81,12 @@ install_tensorrt() {
         if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
         if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
         if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
-        if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
-        RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
+        if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
+        RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
     fi
     wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
     tar -xf /tmp/TensorRT.tar -C /usr/local/
-    mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
+    mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
     # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
     rm -rf /tmp/TensorRT.tar
 }


@@ -1,330 +0,0 @@
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
// use tokio::sync::RwLock;
use parking_lot::RwLock;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
type InferResult<T> = Result<T, InferError>;
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that tokens the to end stream.
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
queued: Instant,
start: Option<Instant>,
}
impl Stream for Generation {
type Item = usize;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let interval = POLLING_INTERVAL_US.get_or_init(|| {
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
});
if !self.done.load(Ordering::Relaxed) {
let backend = pin!(self.executor.read());
let status = match backend.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
Poll::Pending
} else {
Poll::Ready(Some(ready))
}
}
Poll::Pending => Poll::Pending,
};
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_micros(*interval)).await;
waker.wake();
});
status
} else {
Poll::Ready(None) // end of stream
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
tokenizer: Arc<Tokenizer>,
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
// the number of available tokens (read only) in the Generation stream
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
tokenizer: Arc::new(tokenizer),
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
executor_worker_path.as_ref().to_str().unwrap(),
))),
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(text),
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
},
}
}
fn generate(
&self,
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) {
let tokenizer = Arc::clone(&self.tokenizer);
let executor = Arc::clone(&self.backend);
// Let's push this in async context
tokio::spawn(async move {
// Define the generation state
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
// Define the context over the generation
// TODO(asap): Do we really need so many shared-ownership?
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer,
tokens: vec![],
done: Arc::clone(&generation.done),
start: None,
queued: Instant::now(),
});
// We are leaking the context on-purpose to avoid the box being dropped while there are
// still computation ongoing
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
let ctx_ = Box::leak(ctx);
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "submit")
.in_scope(|| async {
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&tokens,
top_k as i32,
top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
request_id
})
.await;
while let Some(_) = generation.next().await {
let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
.tokenizer
.get_added_vocabulary()
.is_special_token(&text);
// Create the structure holding the token
let token = Token {
id: step.token_id,
text,
logprob: step.log_prob,
special,
};
if step.is_final {
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued,
})
} else {
Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
})
}
} else {
error!("Error caught while decoding: {}", &step.error_msg);
Err(InferError::GenerationError(step.error_msg))
};
// Send the parcel to the client
inner_ctx
.sender
.send(parcel)
.expect("Failed to sent msg through the channel");
},
);
}
})
.await;
}
// "Properly" free the shared context...
// TODO: clean that piece of sh** asap
unsafe {
let _ = Box::from_raw(ctx_);
}
});
}
}
#[async_trait]
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
// Let's add a few more validation
let input = TensorRtLlmBackend::validate(&request)?;
// Channel to stream the generated token as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = &request.parameters;
// Preprocess the inputs to send to TRTLLM backend
let encoding = self
.tokenizer
.encode(input.as_str(), true)
.map_err(|e| InferError::GenerationError(e.to_string()))?;
// Generate the response
self.generate(
sender,
Vec::from(encoding.get_ids()),
params.top_k,
params.top_p,
params.temperature,
params.repetition_penalty,
params.frequency_penalty,
params.seed,
);
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}


@@ -1,9 +1,16 @@
+use std::path::PathBuf;
 use thiserror::Error;
 use text_generation_router::server;
 #[derive(Debug, Error)]
 pub enum TensorRtLlmBackendError {
+    #[error("Provided engine folder {0} doesn't exist")]
+    EngineFolderDoesntExists(PathBuf),
+    #[error("Provided executorWorker binary path {0} doesn't exist")]
+    ExecutorWorkerNotFound(PathBuf),
+    #[error("TensorRT-LLM Runtime error: {0}")]
+    Runtime(String),
     #[error("Tokenizer error: {0}")]
     Tokenizer(String),
     #[error("Argument validation error: {0}")]


@@ -3,11 +3,13 @@
 //
 #pragma once
-#include <cmath>
+#include <algorithm>
 #include <exception>
 #include <filesystem>
+#include <functional>
 #include <limits>
 #include <iterator>
+#include <ranges>
 #include <vector>
 #include <spdlog/spdlog.h>
@@ -20,61 +22,59 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
 ) : TensorRtLlmBackend(engineFolder, executorWorker) {}
-bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
-    return TensorRtLlmBackend::IsReady();
-}
 uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
-        rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
-        float_t frequency_penalty, uint64_t seed) {
+        rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
+        int32_t topK, float_t topP, float_t temperature,
+        float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {
     // This will copy all the items from the initial slice
-    std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
+    std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
     return TensorRtLlmBackend::Submit(
-            std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
+            std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
 }
-size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
-        const uint64_t requestId,
-        huggingface::tgi::backends::GenerationContext *ctx,
-        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
-                      huggingface::tgi::backends::GenerationStep)> callback) {
-    size_t numTokens = 0;
-    for (const auto &item: Poll(requestId)) {
-        GenerationStep step;
-        if (!item.hasError()) {
-            SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
-            const auto decoded = item.getResult();
-            const auto token = decoded.outputTokenIds[0][0];
-            const auto isFinal = decoded.isFinal;
-            const auto logProb = decoded.logProbs.value()[0][0];
-            ++numTokens;
-            SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{
-                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
-            };
-            SPDLOG_DEBUG("\tStreamTokens -> Post callback");
-        } else {
-            // TODO : Return rest::Result with error
-            const auto what = item.getErrorMsg();
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
-            step = huggingface::tgi::backends::GenerationStep{
-                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
-            };
-        }
-        callback(std::move(ctx), std::move(step));
-    }
-    return numTokens;
-}
+std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
+huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
+    const auto responses = TensorRtLlmBackend::PullNewTokens();
+    auto steps = std::make_unique<std::vector<GenerationStep>>();
+    steps->reserve(responses.size());
+#ifndef NDEBUG
+    SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
+#endif
+    // Transform tle::Response to GenerationStep
+    std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) {
+            const auto result = r.getResult();
+            return GenerationStep{
+                    reqId,
+                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                    result.logProbs.value()[0][0],
+                    result.isFinal,
+                    false,
+                    std::string()
+            };
+        } else {
+            return GenerationStep{
+                    reqId,
+                    0,
+                    0.0,
+                    true,
+                    true,
+                    std::move(r.getErrorMsg())
+            };
+        }
+    });
+    return steps;
+}
 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
 huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
+    SPDLOG_INFO("Creating TensorRT-LLM Backend");
     // Unconditionally call this to initialize and discover TRTLLM plugins
     InitializeBackend();


@@ -1,14 +1,16 @@
-pub use backend::{GenerationContext, TensorRtLlmBackend};
-mod backend;
+pub use looper::TensorRtLlmBackendV2;
 pub mod errors;
+mod looper;
+mod utils;
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
+    #[derive(Debug, Clone)]
     pub struct GenerationStep {
+        request_id: u64,
         token_id: u32,
         log_prob: f32,
         is_final: bool,
@@ -16,10 +18,6 @@ mod ffi {
         error_msg: String,
     }
-    extern "Rust" {
-        type GenerationContext;
-    }
     unsafe extern "C++" {
         include!("backends/trtllm/src/ffi.cpp");
@@ -44,10 +42,7 @@ mod ffi {
         fn CreateTensorRtLlmBackend(
             engine_folder: &str,
             executor_worker: &str,
-        ) -> UniquePtr<TensorRtLlmBackendImpl>;
-        // #[rust_name = "is_ready"]
-        // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+        ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
         #[rust_name = "num_responses_ready"]
         fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@@ -56,23 +51,18 @@ mod ffi {
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             tokens: &[u32],
+            max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
             temperature: f32,
             repetition_penalty: f32,
             frequency_penalty: f32,
             seed: u64,
-        ) -> u64;
+        ) -> Result<u64>;
-        #[rust_name = "stream_tokens"]
-        unsafe fn StreamTokens(
-            self: Pin<&mut TensorRtLlmBackendImpl>,
-            request_id: u64,
-            ctx: *mut GenerationContext,
-            cb: unsafe fn(*mut GenerationContext, GenerationStep),
-        ) -> usize;
-        // #[rust_name = "shutdown"]
-        // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
+        #[rust_name = "pull_tokens"]
+        fn PullTokens(
+            self: Pin<&mut TensorRtLlmBackendImpl>,
+        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
     }
 }
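With the callback-based stream_tokens replaced by a polling pull_tokens that returns a CxxVector<GenerationStep>, and Submit/CreateTensorRtLlmBackend now fallible, the Rust side can drain generation steps in a plain loop. A minimal sketch, assuming it lives inside this crate next to the bridge (the drain_ready helper name is illustrative; the real scheduling loop is executor_status_looper in src/looper.rs below):

```rust
// Sketch (inside this crate): drain ready responses from the executor and hand
// the GenerationStep values back to the caller, mirroring executor_status_looper.
use cxx::UniquePtr;
use crate::ffi::{GenerationStep, TensorRtLlmBackendImpl};

fn drain_ready(backend: &mut UniquePtr<TensorRtLlmBackendImpl>) -> Result<Vec<GenerationStep>, String> {
    if backend.num_responses_ready() == 0 {
        return Ok(Vec::new());
    }
    let steps = backend
        .pin_mut()
        .pull_tokens()
        .map_err(|e| e.what().to_string())?; // cxx::Exception from the C++ trycatch template
    // GenerationStep now derives Clone, so the CxxVector view can be copied out.
    Ok(steps.iter().cloned().collect())
}
```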


@@ -0,0 +1,395 @@
use std::hint;
use std::ops::Deref;
use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
struct IdentifiableRequest<T> {
request_id: u64,
inner: T,
}
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
request: ValidGenerateRequest,
start: Option<Instant>,
queued: Instant,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
})
} else {
Err(GenerationError(step.error_msg.clone()))
}
}
}
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
struct DecodedTokenContext {
token: DecodedToken,
start: Option<Instant>,
queued: Instant,
channel: UnboundedSender<InferResult<InferStreamResponse>>,
}
fn executor_status_looper(
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
mut waiting_requests: UnboundedReceiver<GenerationContext>,
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
// TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = waiting_requests.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(mut ctx) = waiting_requests.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
let stopping_params = &request.stopping_parameters;
let input_ids = request.input_ids.as_deref();
// Submit to the TensorRT-LLM executor for scheduling
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
generation_params.top_k as i32,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
generation_params.frequency_penalty,
generation_params.seed,
) {
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
// Return to the caller
let what = e.to_string();
error!(error = what.as_str(), "Failed to schedule request");
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
if let Err(_) = ctx.streamer.send(err) {
error!("Failed to send back error to the client");
}
}
};
}
}
if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded token
for step in responses.deref() {
if let Some(ctx) = in_flights.get(&step.request_id) {
// Remove from tracked requests
let parcel =
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
token: dt,
start: ctx.start,
queued: ctx.queued,
channel: ctx.streamer.clone(),
});
// Submit the work to the post_processor
let posted = post_processor_sender.send((step.request_id, parcel));
if posted.is_err() || step.is_final {
debug!("Removing {}", step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
}
}
}
Err(ref err) => {
error!("Failed to get responses from the executor: {}.", err.what());
break 'scheduler;
}
}
}
// Hint the CPU we are spin-locking
hint::spin_loop();
}
}
fn post_processor_looper(
tokenizer: Tokenizer,
max_num_tokens: usize,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
'post_processor: loop {
if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor;
}
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
match decoded {
Ok(ctx) => {
states
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(max_num_tokens);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: ctx.token.id,
text,
logprob: ctx.token.log_prob,
special: is_special,
};
let out = if !ctx.token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
};
if let Err(_) = ctx.channel.send(out) {
warn!("Failed to send decoded token back to the user")
}
}
Err(_err) => {
todo!("what do we do?")
}
}
}
}
}
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
engine_folder: P,
executor_worker_path: PP,
) -> Result<(String, String), TensorRtLlmBackendError> {
// Retrieve paths as &str for the backend creation
let engine_folder = engine_folder.as_ref();
let executor_worker_path = executor_worker_path.as_ref();
// Ensure the engine folder exists
if !engine_folder.exists() {
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
// Ensure executor worker binary exists
if !executor_worker_path.exists() {
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
let engine_folder = String::from(
engine_folder
.to_str()
.expect("Failed to convert engine_folder to valid UTF-8"),
);
let executor_worker_path = String::from(
executor_worker_path
.to_str()
.expect("Failed to convert executor_worker_path to valid UTF-8"),
);
Ok((engine_folder, executor_worker_path))
}
unsafe impl Send for TensorRtLlmBackendImpl {}
pub struct TensorRtLlmBackendV2 {
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
}
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
max_inflight_requests: usize,
) -> Result<Self, TensorRtLlmBackendError> {
let (engine_folder, executor_worker_path) =
ensure_paths_exist(engine_folder, executor_worker_path)?;
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling requests state at regular interval
let executor_looper = spawn_blocking(move || {
executor_status_looper(
backend,
max_inflight_requests,
executor_receiver,
post_processor_sender,
)
});
// Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
let post_processor_looper = spawn_blocking(move || {
post_processor_looper(
tokenizer,
512,
max_inflight_requests,
post_processor_receiver,
)
});
Ok(TensorRtLlmBackendV2 {
executor_looper,
post_processor_looper,
executor: executor_sender,
})
}
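/// Rejects requests the TensorRT-LLM backend cannot serve: missing input ids, `top_n_tokens > 1`,
/// grammar-constrained generation, empty inputs, and multi-chunk or image inputs.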
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
if request.input_ids.is_none() {
return Err(ValidationError(UnsupportedModality("No token provided")));
}
if request.top_n_tokens > 1 {
return Err(ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(ValidationError(EmptyInput)),
2.. => Err(GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
}
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
inner: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Self::validate(&inner)?;
// Open up the stream used to send tokens back to the client
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.executor.send(GenerationContext {
request: inner,
start: None,
queued,
streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
"Failed to submit request to the backend".into(),
)),
}
}
async fn health(&self, current_health: bool) -> bool {
current_health
&& !self.executor_looper.is_finished()
&& !self.post_processor_looper.is_finished()
}
}
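// Illustrative usage sketch (hypothetical paths and request value, assuming a loaded
// `tokenizers::Tokenizer`; not a verbatim excerpt from this change):
//
//     let backend = TensorRtLlmBackendV2::new(tokenizer, "/data/engines", "/opt/executorWorker", 128)?;
//     let stream = backend.schedule(valid_request)?;
//     // The stream yields `InferStreamResponse::Intermediate` items, then a final
//     // `InferStreamResponse::End` carrying the decoded `GeneratedText`.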

View File

@ -1,10 +1,16 @@
use std::path::{Path, PathBuf};
use clap::Parser; use clap::Parser;
use std::collections::HashMap; use hf_hub::api::tokio::{Api, ApiBuilder};
use std::path::PathBuf; use hf_hub::{Cache, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend; use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server; use text_generation_router::server::get_base_tokenizer;
use tokenizers::{FromPretrainedParameters, Tokenizer}; use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -48,14 +54,138 @@ struct Args {
otlp_service_name: String, otlp_service_name: String,
#[clap(long, env)] #[clap(long, env)]
cors_allow_origin: Option<Vec<String>>, cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
#[clap(long, env)] #[clap(long, env)]
auth_token: Option<String>, auth_token: Option<String>,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf, executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
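/// Resolves a `Tokenizer` either from a local directory or from the Hugging Face Hub,
/// honoring `HF_TOKEN`/`HUGGING_FACE_HUB_TOKEN`, `HUGGINGFACE_HUB_CACHE` and `HF_HUB_OFFLINE`.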
async fn get_tokenizer(
tokenizer_name: &str,
tokenizer_config_path: Option<&str>,
revision: Option<&str>,
) -> Option<Tokenizer> {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
let local_path = Path::new(tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
_config_filename,
tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main").to_string(),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
} }
#[tokio::main] #[tokio::main]
@ -83,10 +213,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
otlp_endpoint, otlp_endpoint,
otlp_service_name, otlp_service_name,
cors_allow_origin, cors_allow_origin,
messages_api_enabled,
max_client_batch_size, max_client_batch_size,
auth_token, auth_token,
executor_worker, executor_worker,
usage_stats,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
@ -124,18 +254,26 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
))); )));
} }
// Run server // Create the backend
let tokenizer = Tokenizer::from_pretrained( let tokenizer = get_tokenizer(
tokenizer_name.clone(), &tokenizer_name,
Some(FromPretrainedParameters { tokenizer_config_path.as_deref(),
revision: revision.clone().unwrap_or(String::from("main")), revision.as_deref(),
user_agent: HashMap::new(),
auth_token,
}),
) )
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; .await
.expect("Failed to retrieve tokenizer implementation");
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?; info!("Successfully retrieved tokenizer {}", &tokenizer_name);
let backend = TensorRtLlmBackendV2::new(
tokenizer,
model_id,
executor_worker,
max_concurrent_requests,
)?;
info!("Successfully created backend");
// Run server
server::run( server::run(
backend, backend,
max_concurrent_requests, max_concurrent_requests,
@ -145,7 +283,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
validation_workers, validation_workers,
None, auth_token,
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
@ -155,11 +293,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
false, false,
None, None,
None, None,
messages_api_enabled,
true, true,
max_client_batch_size, max_client_batch_size,
false, usage_stats,
false,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -0,0 +1,22 @@
///
/// Extract the first line of the provided string reference.
/// If there are no lines in the buffer, it returns a string
/// whose content is defined by `fail`.
///
/// # Arguments
///
/// * `s`: The string buffer to extract the first line from
/// * `fail`: The string content returned if no lines are
/// present in `s`
///
/// returns: String
///
/// # Examples
///
/// ```
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
/// first_line(s, "No line in string");
/// ```
#[inline]
pub(crate) fn first_line(s: &str, fail: &str) -> String {
s.lines().next().unwrap_or(fail).to_string()
}
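// A minimal test sketch for `first_line` (illustrative; test names are hypothetical):
#[cfg(test)]
mod tests {
    use super::first_line;

    #[test]
    fn returns_first_line_or_fallback() {
        assert_eq!(first_line("My name is Morgan.\nI work at Hugging Face.", "empty"), "My name is Morgan.");
        assert_eq!(first_line("", "empty"), "empty");
    }
}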

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>, tokenizer_config_path: Option<String>,
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
trust_remote_code,
validation_workers, validation_workers,
api_key, api_key,
json_output, json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats, usage_stats,
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
trust_remote_code,
hostname, hostname,
port, port,
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats, usage_stats,

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>, tokenizer_config_path: Option<String>,
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
trust_remote_code,
validation_workers, validation_workers,
api_key, api_key,
json_output, json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats, usage_stats,
@ -200,13 +200,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name, tokenizer_name,
tokenizer_config_path, tokenizer_config_path,
revision, revision,
trust_remote_code,
hostname, hostname,
port, port,
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats, usage_stats,

View File

@ -316,6 +316,98 @@
} }
} }
}, },
"/invocations": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens from Sagemaker request",
"operationId": "sagemaker_compatibility",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/SagemakerStreamResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error",
"error_type": "validation"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation",
"error_type": "generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded",
"error_type": "overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation",
"error_type": "incomplete_generation"
}
}
}
}
}
}
},
"/metrics": { "/metrics": {
"get": { "get": {
"tags": [ "tags": [
@ -1865,6 +1957,45 @@
"type": "string" "type": "string"
} }
}, },
"SagemakerRequest": {
"oneOf": [
{
"$ref": "#/components/schemas/CompatGenerateRequest"
},
{
"$ref": "#/components/schemas/ChatRequest"
},
{
"$ref": "#/components/schemas/CompletionRequest"
}
]
},
"SagemakerResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/GenerateResponse"
},
{
"$ref": "#/components/schemas/ChatCompletion"
},
{
"$ref": "#/components/schemas/CompletionFinal"
}
]
},
"SagemakerStreamResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/StreamResponse"
},
{
"$ref": "#/components/schemas/ChatCompletionChunk"
},
{
"$ref": "#/components/schemas/Chunk"
}
]
},
"SimpleToken": { "SimpleToken": {
"type": "object", "type": "object",
"required": [ "required": [

View File

@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
## Amazon SageMaker ## Amazon SageMaker
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`. Amazon SageMaker natively supports the Messages API:
This will modify the `/invocations` route to accept Messages dictionaries consisting of role and content. See the example below on how to deploy Llama with the new Messages API.
```python ```python
import json import json
@ -161,12 +159,11 @@ except ValueError:
hub = { hub = {
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta', 'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
'SM_NUM_GPUS': json.dumps(1), 'SM_NUM_GPUS': json.dumps(1),
'MESSAGES_API_ENABLED': True
} }
# create Hugging Face Model Class # create Hugging Face Model Class
huggingface_model = HuggingFaceModel( huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"), image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
env=hub, env=hub,
role=role, role=role,
) )

View File

@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b) - [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)

View File

@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
"max_top_n_tokens": 5, "max_top_n_tokens": 5,
"max_total_tokens": 2048, "max_total_tokens": 2048,
"max_waiting_tokens": 20, "max_waiting_tokens": 20,
"messages_api_enabled": false,
"model_config": { "model_config": {
"model_type": "Bloom" "model_type": "Bloom"
}, },

View File

@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1728381423, "lastModified": 1729531056,
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=", "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e", "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "huggingface", "owner": "huggingface",
"ref": "marlin-kernels-0.3.0",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"type": "github" "type": "github"
} }

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix"; tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {
@ -137,6 +137,11 @@
impure = callPackage ./nix/impure-shell.nix { inherit server; }; impure = callPackage ./nix/impure-shell.nix { inherit server; };
impureWithCuda = callPackage ./nix/impure-shell.nix {
inherit server;
withCuda = true;
};
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix { impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; }; server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
}; };

View File

@ -11,27 +11,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.1875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.93359375,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.875,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1796875,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -39,66 +39,66 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.109375,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.079956055, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.2763672, "logprob": -0.028808594,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37548828, "logprob": -0.013671875,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4628906, "logprob": -0.69921875,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02885437, "logprob": -0.0005874634,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.2565918, "logprob": -0.026855469,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0063438416, "logprob": -0.00020885468,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3056641, "logprob": -0.17773438,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.6035156, "logprob": -0.703125,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
} }

View File

@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "eos_token", "finish_reason": "length",
"generated_tokens": 3, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 128000, "id": 128000,
@ -11,22 +11,22 @@
}, },
{ {
"id": 374, "id": 374,
"logprob": -22.96875, "logprob": -18.0,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -10.71875, "logprob": -11.75,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -2.6992188, "logprob": -2.0625,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -4.8398438, "logprob": -6.0,
"text": "?" "text": "?"
} }
], ],
@ -34,24 +34,66 @@
"tokens": [ "tokens": [
{ {
"id": 720, "id": 720,
"logprob": -0.4411621, "logprob": 0.0,
"special": false, "special": false,
"text": " \n" "text": " \n"
}, },
{ {
"id": 220, "id": 34564,
"logprob": -0.35864258, "logprob": -0.11279297,
"special": false, "special": false,
"text": " " "text": "Deep"
}, },
{ {
"id": 128001, "id": 6975,
"logprob": -0.16015625,
"special": false,
"text": " learning"
},
{
"id": 320,
"logprob": -0.25195312,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"special": false,
"text": "DL"
},
{
"id": 8,
"logprob": 0.0, "logprob": 0.0,
"special": true, "special": false,
"text": "<|end_of_text|>" "text": ")"
},
{
"id": 374,
"logprob": -1.140625,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1207,
"logprob": -1.3125,
"special": false,
"text": " sub"
},
{
"id": 2630,
"logprob": 0.0,
"special": false,
"text": "field"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "What is deep learning? \n " "generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
} }

View File

@ -12,27 +12,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.1875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.93359375,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.875,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1796875,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -40,68 +40,68 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.109375,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.0047912598,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.025512695,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.012145996,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.72265625,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0005760193,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02722168,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00023651123,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.17285156,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.703125,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
}, },
{ {
"details": { "details": {
@ -116,27 +116,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.21875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.95703125,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.9375,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1328125,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -144,68 +144,68 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.1796875,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.02758789,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.013366699,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.6953125,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0004863739,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02709961,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00022506714,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.19726562,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.77734375,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
}, },
{ {
"details": { "details": {
@ -220,27 +220,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.21875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.95703125,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.9375,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1328125,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -248,68 +248,68 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.1796875,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.02758789,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.013366699,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.6953125,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0004863739,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02709961,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00022506714,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.19726562,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.77734375,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
}, },
{ {
"details": { "details": {
@ -324,27 +324,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.21875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.95703125,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.9375,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1328125,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -352,67 +352,67 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.1796875,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.02758789,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.013366699,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.6953125,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0004863739,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02709961,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00022506714,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.19726562,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.77734375,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
} }
] ]

View File

@ -11,32 +11,32 @@
}, },
{ {
"id": 338, "id": 338,
"logprob": -0.7133789, "logprob": -0.6201172,
"text": "is" "text": "is"
}, },
{ {
"id": 16030, "id": 16030,
"logprob": -13.9296875, "logprob": -13.6484375,
"text": "gradient" "text": "gradient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.048919678, "logprob": -0.003894806,
"text": "descent" "text": "descent"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -3.0078125, "logprob": -2.6386719,
"text": "?" "text": "?"
}, },
{ {
"id": 13, "id": 13,
"logprob": -2.8105469, "logprob": -6.46875,
"text": "\n" "text": "\n"
}, },
{ {
"id": 13, "id": 13,
"logprob": -0.84521484, "logprob": -6.6875,
"text": "\n" "text": "\n"
} }
], ],
@ -44,66 +44,66 @@
"tokens": [ "tokens": [
{ {
"id": 25584, "id": 25584,
"logprob": -0.017028809, "logprob": -0.008979797,
"special": false, "special": false,
"text": "Grad" "text": "Grad"
}, },
{ {
"id": 993, "id": 993,
"logprob": -0.0027313232, "logprob": -8.34465e-07,
"special": false, "special": false,
"text": "ient" "text": "ient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.023254395, "logprob": -0.0009407997,
"special": false, "special": false,
"text": " descent" "text": " descent"
}, },
{ {
"id": 338, "id": 338,
"logprob": -2.0623207e-05, "logprob": -0.0003838539,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 263, "id": 385,
"logprob": -0.5361328, "logprob": -0.24499512,
"special": false, "special": false,
"text": " a" "text": " an"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
}, },
{ {
"id": 13883, "id": 13883,
"logprob": -0.47436523, "logprob": -0.010406494,
"special": false, "special": false,
"text": " optimization" "text": " optimization"
}, },
{ {
"id": 5687, "id": 5687,
"logprob": -0.00027680397, "logprob": -0.00024354458,
"special": false, "special": false,
"text": " algorithm" "text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Gradient descent is a first-order optimization algorithm" "generated_text": "Gradient descent is an optimization algorithm commonly used in"
} }

View File

@ -5,95 +5,95 @@
"generated_tokens": 10, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 16030, "id": 338,
"logprob": null, "logprob": null,
"text": "is"
},
{
"id": 16030,
"logprob": -13.328125,
"text": "gradient" "text": "gradient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -6.4960938, "logprob": -0.24023438,
"text": "descent" "text": "descent"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -5.1484375, "logprob": -3.1386719,
"text": "?" "text": "?"
}, },
{ {
"id": 13, "id": 13,
"logprob": -4.0351562, "logprob": -3.0878906,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"text": "\n" "text": "\n"
} }
], ],
"seed": 0, "seed": 0,
"tokens": [ "tokens": [
{ {
"id": 10994, "id": 25584,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "!" "text": "Grad"
}, },
{ {
"id": 739, "id": 993,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " It" "text": "ient"
}, },
{ {
"id": 2444, "id": 2726,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " you" "text": " Des"
}, },
{ {
"id": 29915, "id": 1760,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "'" "text": "cent"
}, },
{ {
"id": 276, "id": 313,
"logprob": -0.9838867, "logprob": -0.12322998,
"special": false, "special": false,
"text": "re" "text": " ("
}, },
{ {
"id": 3211, "id": 29954,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " address" "text": "G"
}, },
{ {
"id": 292, "id": 29928,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": "ing" "text": "D"
}, },
{ {
"id": 263, "id": 29897,
"logprob": -0.15124512, "logprob": 0.0,
"special": false, "special": false,
"text": " a" "text": ")"
},
{
"id": 338,
"logprob": -0.6040039,
"special": false,
"text": " is"
},
{
"id": 385,
"logprob": -0.1796875,
"special": false,
"text": " an"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a" "generated_text": "What is gradient descent?\nGradient Descent (GD) is an"
} }

View File

@ -12,32 +12,32 @@
}, },
{ {
"id": 338, "id": 338,
"logprob": -0.7133789, "logprob": -0.6201172,
"text": "is" "text": "is"
}, },
{ {
"id": 16030, "id": 16030,
"logprob": -13.9296875, "logprob": -13.6484375,
"text": "gradient" "text": "gradient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.048919678, "logprob": -0.003894806,
"text": "descent" "text": "descent"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -3.0078125, "logprob": -2.6386719,
"text": "?" "text": "?"
}, },
{ {
"id": 13, "id": 13,
"logprob": -2.8105469, "logprob": -6.46875,
"text": "\n" "text": "\n"
}, },
{ {
"id": 13, "id": 13,
"logprob": -0.84521484, "logprob": -6.6875,
"text": "\n" "text": "\n"
} }
], ],
@ -45,68 +45,68 @@
"tokens": [ "tokens": [
{ {
"id": 25584, "id": 25584,
"logprob": -0.017028809, "logprob": -0.008979797,
"special": false, "special": false,
"text": "Grad" "text": "Grad"
}, },
{ {
"id": 993, "id": 993,
"logprob": -0.0028476715, "logprob": -8.34465e-07,
"special": false, "special": false,
"text": "ient" "text": "ient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.023971558, "logprob": -0.00097084045,
"special": false, "special": false,
"text": " descent" "text": " descent"
}, },
{ {
"id": 338, "id": 338,
"logprob": -2.0384789e-05, "logprob": -0.0003838539,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 263, "id": 385,
"logprob": -0.5229492, "logprob": -0.23840332,
"special": false, "special": false,
"text": " a" "text": " an"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
}, },
{ {
"id": 13883, "id": 13883,
"logprob": -0.47436523, "logprob": -0.010406494,
"special": false, "special": false,
"text": " optimization" "text": " optimization"
}, },
{ {
"id": 5687, "id": 5687,
"logprob": -0.00027871132, "logprob": -0.0002501011,
"special": false, "special": false,
"text": " algorithm" "text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Gradient descent is a first-order optimization algorithm" "generated_text": "Gradient descent is an optimization algorithm commonly used in"
}, },
{ {
"details": { "details": {
@ -121,32 +121,32 @@
}, },
{ {
"id": 338, "id": 338,
"logprob": -0.7128906, "logprob": -0.6113281,
"text": "is" "text": "is"
}, },
{ {
"id": 16030, "id": 16030,
"logprob": -13.9375, "logprob": -13.6640625,
"text": "gradient" "text": "gradient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.05053711, "logprob": -0.003929138,
"text": "descent" "text": "descent"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -3.0058594, "logprob": -2.625,
"text": "?" "text": "?"
}, },
{ {
"id": 13, "id": 13,
"logprob": -2.8242188, "logprob": -6.484375,
"text": "\n" "text": "\n"
}, },
{ {
"id": 13, "id": 13,
"logprob": -0.84521484, "logprob": -6.6875,
"text": "\n" "text": "\n"
} }
], ],
@ -154,68 +154,68 @@
"tokens": [ "tokens": [
{ {
"id": 25584, "id": 25584,
"logprob": -0.018859863, "logprob": -0.009017944,
"special": false, "special": false,
"text": "Grad" "text": "Grad"
}, },
{ {
"id": 993, "id": 993,
"logprob": -0.002822876, "logprob": -9.536743e-07,
"special": false, "special": false,
"text": "ient" "text": "ient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.023254395, "logprob": -0.00097084045,
"special": false, "special": false,
"text": " descent" "text": " descent"
}, },
{ {
"id": 338, "id": 338,
"logprob": -2.0384789e-05, "logprob": -0.0003838539,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 263, "id": 385,
"logprob": -0.5229492, "logprob": -0.24499512,
"special": false, "special": false,
"text": " a" "text": " an"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
}, },
{ {
"id": 13883, "id": 13883,
"logprob": -0.47436523, "logprob": -0.010406494,
"special": false, "special": false,
"text": " optimization" "text": " optimization"
}, },
{ {
"id": 5687, "id": 5687,
"logprob": -0.00027036667, "logprob": -0.0002501011,
"special": false, "special": false,
"text": " algorithm" "text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Gradient descent is a first-order optimization algorithm" "generated_text": "Gradient descent is an optimization algorithm commonly used in"
}, },
{ {
"details": { "details": {
@ -230,32 +230,32 @@
}, },
{ {
"id": 338, "id": 338,
"logprob": -0.71484375, "logprob": -0.609375,
"text": "is" "text": "is"
}, },
{ {
"id": 16030, "id": 16030,
"logprob": -13.9375, "logprob": -13.671875,
"text": "gradient" "text": "gradient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.049346924, "logprob": -0.0040016174,
"text": "descent" "text": "descent"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -3.0078125, "logprob": -2.6230469,
"text": "?" "text": "?"
}, },
{ {
"id": 13, "id": 13,
"logprob": -2.8242188, "logprob": -6.453125,
"text": "\n" "text": "\n"
}, },
{ {
"id": 13, "id": 13,
"logprob": -0.86328125, "logprob": -6.6875,
"text": "\n" "text": "\n"
} }
], ],
@ -263,68 +263,68 @@
"tokens": [ "tokens": [
{ {
"id": 25584, "id": 25584,
"logprob": -0.017196655, "logprob": -0.008956909,
"special": false, "special": false,
"text": "Grad" "text": "Grad"
}, },
{ {
"id": 993, "id": 993,
"logprob": -0.0028438568, "logprob": -8.34465e-07,
"special": false, "special": false,
"text": "ient" "text": "ient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.023254395, "logprob": -0.0009407997,
"special": false, "special": false,
"text": " descent" "text": " descent"
}, },
{ {
"id": 338, "id": 338,
"logprob": -2.026558e-05, "logprob": -0.0003721714,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 263, "id": 385,
"logprob": -0.5229492, "logprob": -0.24499512,
"special": false, "special": false,
"text": " a" "text": " an"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
}, },
{ {
"id": 13883, "id": 13883,
"logprob": -0.48608398, "logprob": -0.010406494,
"special": false, "special": false,
"text": " optimization" "text": " optimization"
}, },
{ {
"id": 5687, "id": 5687,
"logprob": -0.00027894974, "logprob": -0.0002501011,
"special": false, "special": false,
"text": " algorithm" "text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092601776,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19177246,
"special": false,
"text": " in"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Gradient descent is a first-order optimization algorithm" "generated_text": "Gradient descent is an optimization algorithm commonly used in"
}, },
{ {
"details": { "details": {
@ -339,32 +339,32 @@
}, },
{ {
"id": 338, "id": 338,
"logprob": -0.7192383, "logprob": -0.609375,
"text": "is" "text": "is"
}, },
{ {
"id": 16030, "id": 16030,
"logprob": -13.9375, "logprob": -13.6640625,
"text": "gradient" "text": "gradient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.050445557, "logprob": -0.0038967133,
"text": "descent" "text": "descent"
}, },
{ {
"id": 29973, "id": 29973,
"logprob": -3.0078125, "logprob": -2.6347656,
"text": "?" "text": "?"
}, },
{ {
"id": 13, "id": 13,
"logprob": -2.8242188, "logprob": -6.453125,
"text": "\n" "text": "\n"
}, },
{ {
"id": 13, "id": 13,
"logprob": -0.8276367, "logprob": -6.6875,
"text": "\n" "text": "\n"
} }
], ],
@ -372,67 +372,67 @@
"tokens": [ "tokens": [
{ {
"id": 25584, "id": 25584,
"logprob": -0.01727295, "logprob": -0.008979797,
"special": false, "special": false,
"text": "Grad" "text": "Grad"
}, },
{ {
"id": 993, "id": 993,
"logprob": -0.0027542114, "logprob": -9.536743e-07,
"special": false, "special": false,
"text": "ient" "text": "ient"
}, },
{ {
"id": 26815, "id": 26815,
"logprob": -0.023254395, "logprob": -0.0009407997,
"special": false, "special": false,
"text": " descent" "text": " descent"
}, },
{ {
"id": 338, "id": 338,
"logprob": -2.0384789e-05, "logprob": -0.00038409233,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 263, "id": 385,
"logprob": -0.5229492, "logprob": -0.24499512,
"special": false, "special": false,
"text": " a" "text": " an"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
}, },
{ {
"id": 13883, "id": 13883,
"logprob": -0.48608398, "logprob": -0.010414124,
"special": false, "special": false,
"text": " optimization" "text": " optimization"
}, },
{ {
"id": 5687, "id": 5687,
"logprob": -0.00027894974, "logprob": -0.00024354458,
"special": false, "special": false,
"text": " algorithm" "text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Gradient descent is a first-order optimization algorithm" "generated_text": "Gradient descent is an optimization algorithm commonly used in"
} }
] ]

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher): def flash_llama_fp8_kv_cache_handle(launcher):
with launcher( with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2" "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
num_shard=2,
kv_cache_dtype="fp8_e4m3fn",
) as handle: ) as handle:
yield handle yield handle
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
assert ( assert (
response.generated_text response.generated_text
== " Deep learning is a subset of machine learning that is" == " Deep learning is a subset of machine learning that involves"
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == response_snapshot
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
assert len(responses) == 4 assert len(responses) == 4
assert ( assert (
responses[0].generated_text responses[0].generated_text
== " Deep learning is a subset of machine learning that is" == " Deep learning is a subset of machine learning that involves"
) )
assert all( assert all(
[r.generated_text == responses[0].generated_text for r in responses] [r.generated_text == responses[0].generated_text for r in responses]

View File

@ -25,7 +25,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert ( assert (
response.generated_text response.generated_text
== "Gradient descent is a first-order optimization algorithm" == "Gradient descent is an optimization algorithm commonly used in"
) )
assert response == response_snapshot assert response == response_snapshot
@ -33,7 +33,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot): async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate( response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n", "What is gradient descent?\n",
max_new_tokens=10, max_new_tokens=10,
repetition_penalty=1.2, repetition_penalty=1.2,
return_full_text=True, return_full_text=True,
@ -51,7 +51,7 @@ async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert ( assert (
response.generated_text response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a" == "What is gradient descent?\nGradient Descent (GD) is an"
) )
assert response == response_snapshot assert response == response_snapshot
@ -66,7 +66,7 @@ async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_sna
assert responses[0].details.generated_tokens == 10 assert responses[0].details.generated_tokens == 10
assert ( assert (
responses[0].generated_text responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm" == "Gradient descent is an optimization algorithm commonly used in"
) )
assert all( assert all(
[r.generated_text == responses[0].generated_text for r in responses] [r.generated_text == responses[0].generated_text for r in responses]

View File

@ -1108,6 +1108,8 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
} }
} }
} }
} else {
break;
} }
} }
} }
@ -1519,6 +1521,10 @@ fn spawn_webserver(
router_args.push(revision.to_string()) router_args.push(revision.to_string())
} }
if args.trust_remote_code {
router_args.push("--trust-remote-code".to_string());
}
if args.json_output { if args.json_output {
router_args.push("--json-output".to_string()); router_args.push("--json-output".to_string());
} }

View File

@ -1,7 +1,12 @@
{ {
lib,
mkShell, mkShell,
black, black,
cmake,
isort, isort,
ninja,
which,
cudaPackages,
openssl, openssl,
pkg-config, pkg-config,
protobuf, protobuf,
@ -11,14 +16,17 @@
ruff, ruff,
rust-bin, rust-bin,
server, server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}: }:
mkShell { mkShell {
buildInputs = nativeBuildInputs =
[ [
black black
isort isort
openssl.dev
pkg-config pkg-config
(rust-bin.stable.latest.default.override { (rust-bin.stable.latest.default.override {
extensions = [ extensions = [
@ -31,6 +39,19 @@ mkShell {
redocly redocly
ruff ruff
] ]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [ ++ (with python3.pkgs; [
venvShellHook venvShellHook
docker docker
@ -40,10 +61,29 @@ mkShell {
pytest pytest
pytest-asyncio pytest-asyncio
syrupy syrupy
]); ])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ]; inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv"; venvDir = "./.venv";
postVenvCreation = '' postVenvCreation = ''
@ -51,6 +91,7 @@ mkShell {
( cd server ; python -m pip install --no-dependencies -e . ) ( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . ) ( cd clients/python ; python -m pip install --no-dependencies -e . )
''; '';
postShellHook = '' postShellHook = ''
unset SOURCE_DATE_EPOCH unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin export PATH=$PATH:~/.cargo/bin

View File

@ -150,6 +150,7 @@ pub enum Config {
Idefics2(Idefics2), Idefics2(Idefics2),
Ssm, Ssm,
GptBigcode, GptBigcode,
Granite,
Santacoder, Santacoder,
Bloom, Bloom,
Mpt, Mpt,

View File

@ -8,6 +8,7 @@ pub mod validation;
mod kserve; mod kserve;
pub mod logging; pub mod logging;
mod sagemaker;
pub mod usage_stats; pub mod usage_stats;
mod vertex; mod vertex;

View File

@ -1,748 +0,0 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t)]
disable_usage_stats: bool,
#[clap(long, env, default_value_t)]
disable_crash_reports: bool,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
command,
} = args;
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins
// Map into the Option, parse each String into a HeaderValue,
// then convert the resulting list into an AllowOrigin.
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
let preprocessor_config =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation are disabled");
}
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = if !disable_usage_stats && is_container {
let reduced_args = usage_stats::Args::new(
config.clone(),
tokenizer_class,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
);
Some(usage_stats::UserAgent::new(reduced_args))
} else {
None
};
if let Some(ref ua) = user_agent {
let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move {
start_event.send().await;
});
};
// Run server
let result = server::run(
master_shard_uds_path,
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
tokenizer,
config,
validation_workers,
addr,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
preprocessor_config,
processor_config,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await;
match result {
Ok(_) => {
if let Some(ref ua) = user_agent {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
};
Ok(())
}
Err(e) => {
if let Some(ref ua) = user_agent {
if !disable_crash_reports {
let error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some(e.to_string()),
);
error_event.send().await;
} else {
let unknown_error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some("unknown_error".to_string()),
);
unknown_error_event.send().await;
}
};
Err(RouterError::WebServer(e))
}
}
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an OpenTelemetry collector
/// - otlp_service_name is the service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
/// - LOG_FORMAT may be TEXT or JSON (defaults to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (defaults to "true" on ANSI-supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override so simple log levels are not spammed with tokio-level information
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as a `serde_json::Value`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
use super::*;
use text_generation_router::TokenizerConfigToken;
#[test]
fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig {
add_bos_token: None,
add_eos_token: None,
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
chat_template: None,
tokenizer_class: None,
completion_template: None,
};
let tokenizer =
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0")
.unwrap()
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
.unwrap()
.special_tokens(vec![("<s>".to_string(), 1)])
.build()
.unwrap();
assert_eq!(post_processor, expected);
}
}

82
router/src/sagemaker.rs Normal file
View File

@ -0,0 +1,82 @@
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
Generate(CompatGenerateRequest),
Chat(ChatRequest),
Completion(CompletionRequest),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
Generate(GenerateResponse),
Chat(ChatCompletion),
Completion(CompletionFinal),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
Generate(StreamResponse),
Chat(ChatCompletionChunk),
Completion(Chunk),
}
/// Generate tokens from Sagemaker request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/invocations",
request_body = SagemakerRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = SagemakerResponse),
("text/event-stream" = SagemakerStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
)
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
}
}
}
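Because `SagemakerRequest` is an untagged enum, the single `/invocations` route dispatches purely on payload shape, trying `CompatGenerateRequest`, then `ChatRequest`, then `CompletionRequest`. A hedged client-side sketch in Python; the host/port and field names are assumptions based on the usual TGI request schemas (`inputs` for generate, `messages` for chat, `prompt` for completions), not part of this diff:

# Hedged sketch: exercising the unified /invocations route from a client.
# Assumes a TGI router listening on localhost:3000; adjust host, port and
# parameters for a real deployment.
import requests

URL = "http://localhost:3000/invocations"

# Deserializes as SagemakerRequest::Generate (CompatGenerateRequest carries `inputs`).
generate = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}

# Deserializes as SagemakerRequest::Chat (ChatRequest carries `messages`).
chat = {"messages": [{"role": "user", "content": "What is Deep Learning?"}], "max_tokens": 20}

# Deserializes as SagemakerRequest::Completion (CompletionRequest carries `prompt`).
completion = {"prompt": "What is Deep Learning?", "max_tokens": 20}

for payload in (generate, chat, completion):
    response = requests.post(URL, json=payload, timeout=60)
    response.raise_for_status()
    print(response.json())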

View File

@ -7,6 +7,10 @@ use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready, kserve_model_metadata, kserve_model_metadata_ready,
}; };
use crate::sagemaker::{
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
__path_sagemaker_compatibility,
};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility; use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse; use crate::ChatTokenizeResponse;
@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
) )
)] )]
#[instrument(skip(infer, req))] #[instrument(skip(infer, req))]
async fn compat_generate( pub(crate) async fn compat_generate(
Extension(default_return_full_text): Extension<bool>, Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer>,
compute_type: Extension<ComputeType>, compute_type: Extension<ComputeType>,
@ -678,7 +682,7 @@ time_per_token,
seed, seed,
) )
)] )]
async fn completions( pub(crate) async fn completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
@ -1202,7 +1206,7 @@ time_per_token,
seed, seed,
) )
)] )]
async fn chat_completions( pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
@ -1513,11 +1517,13 @@ completions,
tokenize, tokenize,
metrics, metrics,
openai_get_model_info, openai_get_model_info,
sagemaker_compatibility,
), ),
components( components(
schemas( schemas(
Info, Info,
CompatGenerateRequest, CompatGenerateRequest,
SagemakerRequest,
GenerateRequest, GenerateRequest,
GrammarType, GrammarType,
ChatRequest, ChatRequest,
@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
ChatCompletion, ChatCompletion,
CompletionRequest, CompletionRequest,
CompletionComplete, CompletionComplete,
SagemakerResponse,
SagemakerStreamResponse,
Chunk, Chunk,
Completion, Completion,
CompletionFinal, CompletionFinal,
@ -1601,13 +1609,13 @@ pub async fn run(
tokenizer_name: String, tokenizer_name: String,
tokenizer_config_path: Option<String>, tokenizer_config_path: Option<String>,
revision: Option<String>, revision: Option<String>,
trust_remote_code: bool,
hostname: String, hostname: String,
port: u16, port: u16,
cors_allow_origin: Option<Vec<String>>, cors_allow_origin: Option<Vec<String>>,
ngrok: bool, ngrok: bool,
_ngrok_authtoken: Option<String>, _ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>, _ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel, usage_stats_level: usage_stats::UsageStatsLevel,
@ -1761,10 +1769,13 @@ pub async fn run(
let auto = transformers.getattr("AutoTokenizer")?; let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?; let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),); let args = (tokenizer_name.to_string(),);
let kwargs = [( let kwargs = [
"revision", (
revision.clone().unwrap_or_else(|| "main".to_string()), "revision",
)] (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py); .into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?; let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?; let save = tokenizer.getattr("save_pretrained")?;
@ -1836,7 +1847,6 @@ pub async fn run(
// max_batch_size, // max_batch_size,
revision.clone(), revision.clone(),
validation_workers, validation_workers,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats_level, usage_stats_level,
@ -1878,7 +1888,6 @@ pub async fn run(
ngrok, ngrok,
_ngrok_authtoken, _ngrok_authtoken,
_ngrok_edge, _ngrok_edge,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
model_info, model_info,
@ -1938,7 +1947,6 @@ async fn start(
ngrok: bool, ngrok: bool,
_ngrok_authtoken: Option<String>, _ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>, _ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
model_info: HubModelInfo, model_info: HubModelInfo,
@ -2253,6 +2261,7 @@ async fn start(
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions)) .route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility)) .route("/vertex", post(vertex_compatibility))
.route("/invocations", post(sagemaker_compatibility))
.route("/tokenize", post(tokenize)); .route("/tokenize", post(tokenize));
if let Some(api_key) = api_key { if let Some(api_key) = api_key {
@ -2288,13 +2297,6 @@ async fn start(
.route("/metrics", get(metrics)) .route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info)); .route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type = let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@ -2302,8 +2304,7 @@ async fn start(
let mut app = Router::new() let mut app = Router::new()
.merge(swagger_ui) .merge(swagger_ui)
.merge(base_routes) .merge(base_routes)
.merge(info_routes) .merge(info_routes);
.merge(aws_sagemaker_route);
#[cfg(feature = "google")] #[cfg(feature = "google")]
{ {

View File

@ -93,7 +93,6 @@ pub struct Args {
// max_batch_size: Option<usize>, // max_batch_size: Option<usize>,
revision: Option<String>, revision: Option<String>,
validation_workers: usize, validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel, usage_stats_level: UsageStatsLevel,
@ -117,7 +116,6 @@ impl Args {
// max_batch_size: Option<usize>, // max_batch_size: Option<usize>,
revision: Option<String>, revision: Option<String>,
validation_workers: usize, validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool, disable_grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel, usage_stats_level: UsageStatsLevel,
@ -138,7 +136,6 @@ impl Args {
// max_batch_size, // max_batch_size,
revision, revision,
validation_workers, validation_workers,
messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
usage_stats_level, usage_stats_level,

View File

@ -31,7 +31,7 @@ install: install-cuda
echo "Installed server" echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
pip install -e ".[bnb]" pip install -e ".[bnb,marlin,moe]"
pip install nvidia-nccl-cu12==2.22.3 pip install nvidia-nccl-cu12==2.22.3
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

1379
server/poetry.lock generated

File diff suppressed because it is too large

View File

@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26" numpy = "^1.26"
marlin-kernels = [ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
moe-kernels = [ moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13" idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13" idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13" idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13" rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13" setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13" transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -28,10 +28,11 @@ else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already. # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache from .kv_cache import KVCache, get_kv_scales
__all__ = [ __all__ = [
"attention", "attention",
"get_kv_scales",
"paged_attention", "paged_attention",
"SUPPORTS_WINDOWING", "SUPPORTS_WINDOWING",
"KVCache", "KVCache",

View File

@ -1,5 +1,5 @@
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ( from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
@ -8,6 +8,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
@ -21,6 +22,8 @@ def paged_attention(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -46,6 +49,8 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
can_scale = kv_cache.can_scale(kv_scales)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
@ -55,10 +60,13 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward( return decode_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(), query.contiguous(),
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )
elif ATTENTION == "flashdecoding": elif ATTENTION == "flashdecoding":
max_q = 1 max_q = 1
@ -204,6 +212,7 @@ def attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,
@ -211,6 +220,8 @@ def attention(
causal: bool = True, causal: bool = True,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
can_scale = kv_cache.can_scale(kv_scales)
if ATTENTION == "flashinfer": if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import ( from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state, prefill_with_paged_kv_state,
@ -220,12 +231,15 @@ def attention(
softcap = 0.0 softcap = 0.0
return prefill_with_paged_kv_state.get().forward( return prefill_with_paged_kv_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(), query.contiguous(),
causal=causal, causal=causal,
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
window_left=window_size_left, window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )
# If we are using flashdecoding or paged, we always use flash-attn for # If we are using flashdecoding or paged, we always use flash-attn for

View File

@ -204,6 +204,7 @@ def use_decode_state(
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
page_size: int, page_size: int,
kv_cache_dtype: torch.dtype,
dtype: torch.dtype, dtype: torch.dtype,
window_left: int, window_left: int,
): ):
@ -240,7 +241,7 @@ def use_decode_state(
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_dim=head_size, head_dim=head_size,
page_size=page_size, page_size=page_size,
data_type=dtype, data_type=kv_cache_dtype,
q_data_type=dtype, q_data_type=dtype,
window_left=window_left, window_left=window_left,
) )

View File

@ -1,6 +1,6 @@
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
@ -14,6 +14,7 @@ def attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,
@ -55,6 +56,8 @@ def paged_attention(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
if softcap is not None: if softcap is not None:

View File

@ -1,8 +1,38 @@
from typing import Tuple from typing import Tuple
from dataclasses import dataclass, field
from loguru import logger
import torch import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a CPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache: class KVCache:
@ -76,6 +106,33 @@ class KVCache:
), ),
) )
def can_scale(self, kv_scales: KVScales) -> bool:
"""Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif (
self.dtype == torch.float8_e4m3fn
and ATTENTION == "flashinfer"
and SYSTEM == "cuda"
):
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the right FP8 cache type/backend, so log this once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales: only a float8_e4m3fn KV cache with flashinfer on CUDA is supported",
)
return False
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property @property
def key(self): def key(self):
"""Get the key cache.""" """Get the key cache."""
@ -94,17 +151,33 @@ class KVCache:
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
kv_scales: KVScales,
): ):
"""Store the key and value at the given slots.""" """Store the key and value at the given slots."""
key_cache = self.kv_cache[0] key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1] value_cache = self.kv_cache[1]
if self.can_scale(kv_scales):
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
scale=kv_scales.key_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if kv_scales.value_scale_cpu != 1.0:
value = fp8_quantize(
value.float(),
scale=kv_scales.value_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if ATTENTION in {"flashdecoding", "flashinfer"}: if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype) key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype) value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead. # put as raw data instead.
key_cache = key_cache.view(torch.uint8) key_cache = key_cache.view(torch.uint8)
@ -151,5 +224,23 @@ def paged_reshape_and_cache(
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention" f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
) )
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to the older, coarser-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)
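A minimal sketch of the `KVScales` contract described in the docstring above: scalar tensors go in, host-side float mirrors come out, and non-scalar tensors are rejected. The values are illustrative (not from a real checkpoint) and the import assumes a working `text_generation_server` install:

import torch
from text_generation_server.layers.attention.kv_cache import KVScales

device = "cuda" if torch.cuda.is_available() else "cpu"
scales = KVScales(
    key_scale=torch.tensor(0.023, dtype=torch.float32, device=device),
    value_scale=torch.tensor(0.031, dtype=torch.float32, device=device),
)
# Kernels that want device tensors read .key_scale/.value_scale; flashinfer-style
# calls that want host scalars read the floats mirrored in __post_init__.
assert isinstance(scales.key_scale_cpu, float)
assert abs(scales.value_scale_cpu - 0.031) < 1e-6

# Non-scalar tensors are rejected up front by __post_init__.
try:
    KVScales(key_scale=torch.ones(2, device=device), value_scale=scales.value_scale)
except ValueError as err:
    print(err)  # Key and value scales must be scalar tensors.

With the default scale of 1.0 that `get_kv_scales` returns when a layer ships no `k_scale`/`v_scale` (or legacy `kv_scale`) tensors, `KVCache.can_scale` returns False and `store` keeps the unquantized path.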

View File

@ -1,7 +1,7 @@
import os import os
from typing import Optional from typing import Optional
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
@ -36,6 +36,8 @@ def paged_attention(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -210,6 +212,7 @@ def attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,

View File

@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
return False return False
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if is_fbgemm_gpu_available(): if is_fbgemm_gpu_available():
if SYSTEM == "cuda": if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability() major, _ = torch.cuda.get_device_capability()
@ -94,6 +100,17 @@ def fp8_quantize(
) )
return qweight, scale return qweight, scale
if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
)
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False) # weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype) finfo = torch.finfo(qdtype)
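Note: when marlin_kernels is not importable, fp8_quantize falls back to the plain PyTorch path that starts at the finfo line above. A hedged, self-contained sketch of what a per-tensor scaled FP8 quantization of this kind computes (requires a PyTorch build with float8 dtypes; toy_scaled_fp8_quant is an illustrative name and the real function's scale/scale_upper_bound handling differs):
import torch

def toy_scaled_fp8_quant(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # Pick a per-tensor scale so the largest |value| maps onto the FP8 dtype's
    # representable maximum, clamp into range, then cast.
    finfo = torch.finfo(qdtype)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qweight, scale.float()

w = torch.randn(16, 32)
qw, s = toy_scaled_fp8_quant(w)
print(qw.dtype, (w - qw.to(torch.float32) * s).abs().max().item())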


@ -11,7 +11,7 @@ from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
if SYSTEM == "ipex": if SYSTEM == "ipex":
from .ipex import QuantLinear from .ipex import QuantLinear
elif SYSTEM in {"cuda", "rocm"}: elif SYSTEM in {"cuda", "rocm"}:
from .cuda import QuantLinear
from .triton import QuantLinear
@dataclass @dataclass


@ -195,6 +195,11 @@ class ModelType(enum.Enum):
"name": "Phi 3", "name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
} }
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = { GEMMA = {
"type": "gemma", "type": "gemma",
"name": "Gemma", "name": "Gemma",
@ -862,7 +867,12 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: elif (
model_type == LLAMA
or model_type == BAICHUAN
or model_type == PHI3
or model_type == GRANITE
):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashCausalLM( return FlashCausalLM(
model_id=model_id, model_id=model_id,
@ -876,7 +886,9 @@ def get_model(
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else: else:
return CausalLM.fallback( return CausalLM.fallback(
model_id, model_id,


@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention, attention,
Seqlen, Seqlen,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm: if self.use_qk_norm:
@ -289,7 +291,12 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -20,6 +20,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex": if SYSTEM != "ipex":
@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
) )
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module):
), ),
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load( self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
) )
@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor, cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
kv_cache: KVCache,
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0 value, (0, self.head_pad_size - self.value_head_size), value=0
) )
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -329,6 +337,7 @@ class DeepseekV2Attention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
# Remove padding. # Remove padding.


@ -39,6 +39,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear, TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear, TensorParallelAdapterRowLinear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module):
], ],
process_group=weights.process_group, process_group=weights.process_group,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
softcap=self.softcap, softcap=self.softcap,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -37,6 +37,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -36,6 +36,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
def load_qkv(config, prefix: str, weights, head_size, num_heads): def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size, head_size=self.head_size,
num_heads=self.num_heads, num_heads=self.num_heads,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row( self.o_proj = load_row(
config, config,
@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -24,6 +24,7 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix, prefix=prefix,
weights=weights, weights=weights,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row( self.o_proj = load_row(
config, config,
@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module):
else: else:
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -27,7 +27,10 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from text_generation_server.layers.attention import KVCache
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
@ -156,7 +159,10 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device, device=weights.device,
) )
self.softmax_scale = self.head_size**-0.5
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)
if self.num_heads % weights.process_group.size() != 0: if self.num_heads % weights.process_group.size() != 0:
raise ValueError( raise ValueError(
@ -176,11 +182,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index) self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=False,
bias=getattr(config, "attention_bias", False),
) )
self.o_proj = TensorParallelAdapterRowLinear.load( self.o_proj = TensorParallelAdapterRowLinear.load(
@ -221,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -230,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module):
query=query, query=query,
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache, kv_cache=kv_cache,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
@ -245,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(
@ -436,6 +451,11 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)
def forward( def forward(
self, self,
hidden_states, hidden_states,
@ -466,13 +486,16 @@ class FlashLlamaLayer(nn.Module):
max_s, max_s,
adapter_data, adapter_data,
) )
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm( normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res attn_output, res
) )
mlp_output = self.dense(normed_attn_res_output, adapter_data) mlp_output = self.dense(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
return mlp_output, attn_res return mlp_output, attn_res
@ -624,6 +647,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else: else:
suffix = "lm_head" suffix = "lm_head"
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
with no_fp8(weights): with no_fp8(weights):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
@ -631,6 +659,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights, weights=weights,
) )
# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -664,4 +702,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
return logits, speculative_logits return logits, speculative_logits
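Note: the Granite support in this file comes down to four scalar knobs read from the config: embedding_multiplier on the input embeddings, attention_multiplier as the attention softmax scale, residual_multiplier on each block's output, and logits_scaling dividing the final logits (baked into the lm_head weights when possible). A compact sketch of where each knob acts; the numeric values are made up and the real ones come from the model's config.json:
import torch

embedding_multiplier = 12.0    # scales token embeddings once, at the input
attention_multiplier = 0.0078  # replaces head_size**-0.5 as the attention softmax scale
residual_multiplier = 0.22     # scales attention/MLP outputs before the residual add
logits_scaling = 16.0          # divides the final logits (or the lm_head weights)

hidden, vocab = 16, 128
embed = torch.nn.Embedding(vocab, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

x = embed(torch.tensor([[1, 2, 3]])) * embedding_multiplier
q = k = v = x                            # stand-in for a single attention block
scores = torch.softmax(attention_multiplier * (q @ k.transpose(-1, -2)), dim=-1)
block_out = scores @ v
x = x + residual_multiplier * block_out  # residual_multiplier dampens each block
logits = lm_head(x) / logits_scaling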


@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module):
], ],
process_group=weights.process_group, process_group=weights.process_group,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module):
key=kv_to_cache[:, 0], key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1], value=kv_to_cache[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -38,6 +38,7 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module):
key=kv_to_cache[:, 0], key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1], value=kv_to_cache[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size, head_size=self.head_size,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True config, prefix=f"{prefix}.dense", weights=weights, bias=True
) )
@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
key=qkv[:, 1], key=qkv[:, 1],
value=qkv[:, 2], value=qkv[:, 2],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


@ -18,6 +18,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
# in llama the dense layer is called "o_proj" and has bias=False # in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load( self.dense = TensorParallelRowLinear.load(
@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
) )
# Reshape key and value and cache # Reshape key and value and cache
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module):
query=query, query=query,
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache, kv_cache=kv_cache,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))


@ -16,6 +16,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
SpeculativeHead, SpeculativeHead,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module):
key=kv_to_cache[:, 0], key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1], value=kv_to_cache[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -12,6 +12,7 @@ from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module):
weights=weights, weights=weights,
bias=config.bias, bias=config.bias,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module):
weights=weights, weights=weights,
bias=config.bias, bias=config.bias,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
kv_cache.store( kv_cache.store(
key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
key=kv[:, :, 0].contiguous(),
value=kv[:, :, 1].contiguous(),
slots=slots,
kv_scales=self.kv_scales,
) )
# Prefill # Prefill
@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module):
key=kv[:, :, 0], key=kv[:, :, 0],
value=kv[:, :, 1], value=kv[:, :, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.dense( return self.dense(


@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row( self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.zeros( self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device self.num_heads, dtype=torch.int32, device=weights.device
) )
@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
kv_cache.store(
key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module):
key=key_value[:, 0], key=key_value[:, 0],
value=key_value[:, 1], value=key_value[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
FastRMSNorm, FastRMSNorm,
@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module):
key=kv_to_cache[:, 0], key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1], value=kv_to_cache[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -2313,6 +2313,7 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
page_size=BLOCK_SIZE, page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype,
dtype=self.dtype, dtype=self.dtype,
window_left=self.sliding_window, window_left=self.sliding_window,
) )


@ -207,7 +207,9 @@ class Weights:
def get_shape(self, tensor_name: str): def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape() return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
def get_tensor(
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name) tensor = f.get_tensor(tensor_name)


@ -172,6 +172,8 @@ def check_openapi(check: bool):
# allow for trailing whitespace since it's not significant # allow for trailing whitespace since it's not significant
# and the precommit hook will remove it # and the precommit hook will remove it
"lint", "lint",
"--skip-rule",
"security-defined",
filename, filename,
], ],
capture_output=True, capture_output=True,