Merge branch 'main' into auto_length

Nicolas Patry 2024-10-25 10:20:00 +02:00 committed by GitHub
commit c3fb2ecdc0
82 changed files with 2616 additions and 2395 deletions

Cargo.lock (generated)

@ -2706,9 +2706,9 @@ dependencies = [
[[package]]
name = "opentelemetry"
version = "0.23.0"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76"
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
dependencies = [
"futures-core",
"futures-sink",
@ -2819,19 +2819,17 @@ dependencies = [
[[package]]
name = "opentelemetry_sdk"
version = "0.23.0"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd"
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
dependencies = [
"async-trait",
"futures-channel",
"futures-executor",
"futures-util",
"glob",
"lazy_static",
"once_cell",
"opentelemetry 0.23.0",
"ordered-float 4.3.0",
"opentelemetry 0.24.0",
"percent-encoding",
"rand",
"thiserror",
@ -4185,16 +4183,17 @@ dependencies = [
"cmake",
"cxx",
"cxx-build",
"hashbrown 0.14.5",
"hf-hub",
"log",
"parking_lot",
"pkg-config",
"text-generation-router",
"thiserror",
"tokenizers 0.19.1",
"tokenizers",
"tokio",
"tokio-stream",
"tracing",
"tracing-opentelemetry 0.24.0",
"tracing-opentelemetry 0.25.0",
"tracing-subscriber",
]
@ -4212,7 +4211,7 @@ dependencies = [
"tabled",
"text-generation-client",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
@ -4292,7 +4291,7 @@ dependencies = [
"serde_json",
"sysinfo",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tokio-stream",
"tower-http",
@ -4341,7 +4340,7 @@ dependencies = [
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tokio-stream",
"tonic 0.10.2",
@ -4392,7 +4391,7 @@ dependencies = [
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tokio-stream",
"tonic 0.10.2",
@ -4514,39 +4513,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.20.0"
@ -4933,14 +4899,14 @@ dependencies = [
[[package]]
name = "tracing-opentelemetry"
version = "0.24.0"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4"
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
dependencies = [
"js-sys",
"once_cell",
"opentelemetry 0.23.0",
"opentelemetry_sdk 0.23.0",
"opentelemetry 0.24.0",
"opentelemetry_sdk 0.24.1",
"smallvec",
"tracing",
"tracing-core",


@ -1,23 +0,0 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
# Include CUDA related libraries and tools to the Rust based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3


@ -10,7 +10,7 @@ COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
@ -26,6 +26,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
ninja-build \
pkg-config \
python3 \
python3-dev \
python3-setuptools \
tar \
wget
@ -82,10 +83,15 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
cd backends/trtllm && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
RUN apt update && apt install -y python3 && \
rm -rf /var/lib/{apt,dpkg,cache,log}/
WORKDIR /usr/local/tgi/bin
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt


@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain OpenAI Chat Completion API-compatible responses.
```bash
curl localhost:3000/v1/chat/completions \
curl localhost:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",


@ -1,5 +1,17 @@
cmake_minimum_required(VERSION 3.20)
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
find_program(CCACHE_EXECUTABLE "ccache")
if (CCACHE_EXECUTABLE)
message(STATUS "Using ccache")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
endif ()
endif ()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20)
@ -14,7 +26,7 @@ set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include"
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query device information at runtime to enable some architecture-specific features
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
#### External dependencies ####
include(cmake/fmt.cmake)
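
The ccache hook added at the top of this CMakeLists only takes effect for Debug builds when a `ccache` binary can be found; a hypothetical local configure/build that would exercise it might look like this (flags mirror the Dockerfile earlier in this diff, paths are illustrative):

```bash
# Assumes ccache is on PATH and TensorRT lives under /usr/local/tensorrt (illustrative paths).
cmake -G Ninja -B build \
      -DCMAKE_BUILD_TYPE=Debug \
      -DTRT_LIB_DIR=/usr/local/tensorrt/lib \
      -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include \
      .
cmake --build build --parallel -t tgi_trtllm_backend_impl
```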


@ -10,16 +10,17 @@ async-trait = "0.1"
async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
hashbrown = "0.14"
hf-hub = { workspace = true }
log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" }
tokenizers = { version = "0.19", features = ["hf-hub"] }
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15"
thiserror = "1.0.62"
thiserror = "1.0.63"
tracing = "0.1"
tracing-opentelemetry = "0.24"
tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
parking_lot = "0.12"
[build-dependencies]
cmake = "0.1"


@ -6,7 +6,7 @@ use std::path::{absolute, PathBuf};
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.5";
const CUDA_REQUIRED_VERSION: &str = "12.6";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
@ -36,7 +36,7 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
let mut install_path = PathBuf::from(install_path);
if !install_path.is_absolute() {
@ -81,7 +81,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf) {
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
let ndebug = match is_debug {
true => "1",
false => "0",
};
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
@ -93,9 +98,14 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
.include("/usr/local/tensorrt/include")
.file("src/ffi.cpp")
.std("c++20")
.define("NDEBUG", ndebug)
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
println!("cargo:rerun-if-changed=cmake/fmt.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
println!("cargo:rerun-if-changed=include/ffi.h");
@ -115,7 +125,7 @@ fn main() {
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
// Build the FFI layer calling the backend above
build_ffi_layer(&deps_folder);
build_ffi_layer(&deps_folder, is_debug);
// Emit linkage search path
probe!("ompi", MPI_REQUIRED_VERSION);


@ -1,6 +1,6 @@
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt
GIT_TAG 11.0.1
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
)
FetchContent_MakeAvailable(fmt)


@ -1,5 +1,6 @@
fetchcontent_declare(
json
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
)
fetchcontent_makeavailable(json)


@ -11,7 +11,7 @@ endif ()
fetchcontent_declare(
spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v1.14.1
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
)
fetchcontent_makeavailable(spdlog)


@ -23,8 +23,9 @@ endif ()
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
GIT_SHALLOW FALSE
DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(trtllm)


@ -23,6 +23,12 @@ namespace huggingface::tgi::backends {
using RequestId = tle::IdType;
using TokenId = tle::TokenIdType;
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
"Submitting inference [{}] to the executor ({:d} already in-flight)");
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
/**
* Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine
@ -54,7 +60,7 @@ namespace huggingface::tgi::backends {
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
);
) noexcept;
/**
*
@ -64,18 +70,15 @@ namespace huggingface::tgi::backends {
const json config;
tle::Executor executor;
/** Frequently accessed variables cached here **/
uint32_t maxNumTokens;
public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker
);
/**
* Indicate if the backend is ready to accept incoming request
* @return true if ready, false otherwise
*/
[[nodiscard]] bool IsReady() const;
/**
* Query the executor for the number of token available for pulling
* @return
@ -95,25 +98,16 @@ namespace huggingface::tgi::backends {
*/
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
const uint32_t maxNewTokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed
);
/**
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector<tle::Response> Poll(RequestId requestId);
/**
* Stop the underlying executor
*/
void Shutdown();
[[nodiscard]] std::vector<tle::Response> PullNewTokens();
};
}


@ -5,20 +5,31 @@
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
#include <cmath>
#include <cstddef>
#include <memory>
#include "backend.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl;
}
// Template to support returning error from TllmException back to Rust in a Result<>
#include <tensorrt_llm/common/tllmException.h>
namespace rust::behavior {
template<typename Try, typename Fail>
static void trycatch(Try &&func, Fail &&fail) noexcept try {
func();
} catch (tensorrt_llm::common::TllmException &e) {
fail(e.what());
}
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends {
// struct GenerationContext;
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
@ -28,15 +39,10 @@ namespace huggingface::tgi::backends {
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @return
*/
bool IsReady() const;
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
@ -47,21 +53,15 @@ namespace huggingface::tgi::backends {
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/***
*
* @param requestId
* @param ctx
* @param callback
* @return
*/
size_t StreamTokens(
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
};
/***


@ -14,7 +14,7 @@
namespace huggingface::hardware::cuda {
#define AMPERE_SM_MAJOR 8
#define HOPPER_SM_MAJOR 8
#define HOPPER_SM_MAJOR 9
/**
* Store information about the version of the CUDA Compute Capabilities detected on the device


@ -1,3 +1,4 @@
#include <cstdlib>
#include <fstream>
#include <fmt/ranges.h>
@ -8,10 +9,23 @@
#include "hardware.h"
void huggingface::tgi::backends::InitializeBackend() {
if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
return std::tolower(c);
});
if (log_level == "debug")
spdlog::set_level(spdlog::level::debug);
else
spdlog::set_level(spdlog::level::info);
}
SPDLOG_INFO("Initializing Backend...");
nvmlInit_v2();
initTrtLlmPlugins();
SPDLOG_INFO("Backend Executor Version: {}", tle::version());
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
if (numGpus.has_value()) {
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
@ -22,7 +36,7 @@ void huggingface::tgi::backends::InitializeBackend() {
[[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1);
tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
// Retrieve the compute capabilities to enable some options at runtime
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
@ -55,12 +69,13 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
}
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) {
const uint32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed) noexcept {
return tle::SamplingConfig(
1, // TGI only use a single beam
topK,
@ -83,26 +98,29 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &executorWorker
) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor(
enginesFolder,
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string()
)) {
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string())) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
}
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
return executor.canEnqueueRequests();
// Cache variables
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
}
[[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
return executor.getNumResponsesReady();
const auto numResponses = executor.getNumResponsesReady();
#ifndef NDEBUG
if(numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
#endif
return numResponses;
}
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const uint32_t maxNewTokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
@ -110,37 +128,23 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const float_t frequency_penalty,
const uint64_t seed
) {
#ifdef NDEBUG
SPDLOG_DEBUG(
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
tokens.size(),
executor.getLatestIterationStats().back().numActiveRequests
);
#else
SPDLOG_DEBUG(
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
fmt::join(tokens, ", "),
executor.getLatestIterationStats().front().numActiveRequests
);
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
#ifndef NDEBUG
{
const auto &iterations = executor.getLatestIterationStats();
const auto &lastIteration = iterations.front();
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
}
#endif
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto output = tle::OutputConfig(true, false, false, true, false);
return executor.enqueueRequest(
tle::Request{tokens, maxNewTokens, true, sampling, output});
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
}
[[nodiscard("Generated tokens result must be used")]]
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
return executor.awaitResponses(requestId);
}
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
SPDLOG_INFO("Shutting down executor");
executor.shutdown();
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
return executor.awaitResponses();
}


@ -2,12 +2,13 @@
set -ex
TRT_VER="10.2.0.19"
CUDA_VER="12.5"
CUDNN_VER="9.2.1.18-1"
NCCL_VER="2.22.3-1+cuda12.5"
CUBLAS_VER="12.5.3.2-1"
NVRTC_VER="12.5.82-1"
TRT_VER_BASE="10.4.0"
TRT_VER_FULL="${TRT_VER_BASE}.26"
CUDA_VER="12.6"
CUDNN_VER="9.5.0.50-1"
NCCL_VER="2.22.3-1+cuda12.6"
CUBLAS_VER="12.6.3.3-1"
NVRTC_VER="12.6.77-1"
for i in "$@"; do
case $i in
@ -32,8 +33,9 @@ install_ubuntu_requirements() {
ARCH=$(uname -m)
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
dpkg -i cuda-keyring_1.0-1_all.deb
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
apt-get update
if [[ $(apt list --installed | grep libcudnn9) ]]; then
@ -71,7 +73,7 @@ install_centos_requirements() {
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.5"
TRT_CUDA_VERSION="12.6"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
@ -79,12 +81,12 @@ install_tensorrt() {
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
fi
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar
}


@ -1,330 +0,0 @@
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
// use tokio::sync::RwLock;
use parking_lot::RwLock;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
type InferResult<T> = Result<T, InferError>;
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user-provided input to be executed along with a channel allowing
/// all the generated tokens for that request to bubble up to the end stream.
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
queued: Instant,
start: Option<Instant>,
}
impl Stream for Generation {
type Item = usize;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let interval = POLLING_INTERVAL_US.get_or_init(|| {
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
});
if !self.done.load(Ordering::Relaxed) {
let backend = pin!(self.executor.read());
let status = match backend.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
Poll::Pending
} else {
Poll::Ready(Some(ready))
}
}
Poll::Pending => Poll::Pending,
};
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_micros(*interval)).await;
waker.wake();
});
status
} else {
Poll::Ready(None) // end of stream
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
tokenizer: Arc<Tokenizer>,
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
// the number of available tokens (read only) in the Generation stream
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
tokenizer: Arc::new(tokenizer),
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
executor_worker_path.as_ref().to_str().unwrap(),
))),
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(text),
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
},
}
}
fn generate(
&self,
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) {
let tokenizer = Arc::clone(&self.tokenizer);
let executor = Arc::clone(&self.backend);
// Let's push this in async context
tokio::spawn(async move {
// Define the generation state
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
// Define the context over the generation
// TODO(asap): Do we really need so many shared-ownership?
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer,
tokens: vec![],
done: Arc::clone(&generation.done),
start: None,
queued: Instant::now(),
});
// We are leaking the context on-purpose to avoid the box being dropped while there are
// still computation ongoing
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
let ctx_ = Box::leak(ctx);
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "submit")
.in_scope(|| async {
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&tokens,
top_k as i32,
top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
request_id
})
.await;
while let Some(_) = generation.next().await {
let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
.tokenizer
.get_added_vocabulary()
.is_special_token(&text);
// Create the structure holding the token
let token = Token {
id: step.token_id,
text,
logprob: step.log_prob,
special,
};
if step.is_final {
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued,
})
} else {
Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
})
}
} else {
error!("Error caught while decoding: {}", &step.error_msg);
Err(InferError::GenerationError(step.error_msg))
};
// Send the parcel to the client
inner_ctx
.sender
.send(parcel)
.expect("Failed to sent msg through the channel");
},
);
}
})
.await;
}
// "Properly" free the shared context...
// TODO: clean that piece of sh** asap
unsafe {
let _ = Box::from_raw(ctx_);
}
});
}
}
#[async_trait]
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
// Let's add a few more validation
let input = TensorRtLlmBackend::validate(&request)?;
// Channel to stream the generated token as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = &request.parameters;
// Preprocess the inputs to send to TRTLLM backend
let encoding = self
.tokenizer
.encode(input.as_str(), true)
.map_err(|e| InferError::GenerationError(e.to_string()))?;
// Generate the response
self.generate(
sender,
Vec::from(encoding.get_ids()),
params.top_k,
params.top_p,
params.temperature,
params.repetition_penalty,
params.frequency_penalty,
params.seed,
);
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}


@ -1,9 +1,16 @@
use std::path::PathBuf;
use thiserror::Error;
use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
#[error("Provided engine folder {0} doesn't exist")]
EngineFolderDoesntExists(PathBuf),
#[error("Provided executorWorker binary path {0} doesn't exist")]
ExecutorWorkerNotFound(PathBuf),
#[error("TensorRT-LLM Runtime error: {0}")]
Runtime(String),
#[error("Tokenizer error: {0}")]
Tokenizer(String),
#[error("Argument validation error: {0}")]


@ -3,11 +3,13 @@
//
#pragma once
#include <cmath>
#include <algorithm>
#include <exception>
#include <filesystem>
#include <functional>
#include <limits>
#include <iterator>
#include <ranges>
#include <vector>
#include <spdlog/spdlog.h>
@ -20,61 +22,59 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
return TensorRtLlmBackend::IsReady();
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
float_t frequency_penalty, uint64_t seed) {
rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {
// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
const uint64_t requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
const auto responses = TensorRtLlmBackend::PullNewTokens();
size_t numTokens = 0;
for (const auto &item: Poll(requestId)) {
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
auto steps = std::make_unique<std::vector<GenerationStep>>();
steps->reserve(responses.size());
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
const auto logProb = decoded.logProbs.value()[0][0];
#ifndef NDEBUG
SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
#endif
++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
// Transform tle::Response to GenerationStep
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
const auto reqId = r.getRequestId();
if (!r.hasError()) {
const auto result = r.getResult();
return GenerationStep{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
};
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
const auto what = item.getErrorMsg();
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
step = huggingface::tgi::backends::GenerationStep{
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
return GenerationStep{
reqId,
0,
0.0,
true,
true,
std::move(r.getErrorMsg())
};
}
});
callback(std::move(ctx), std::move(step));
}
return numTokens;
return steps;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
SPDLOG_INFO("Creating TensorRT-LLM Backend");
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();


@ -1,14 +1,16 @@
pub use backend::{GenerationContext, TensorRtLlmBackend};
pub use looper::TensorRtLlmBackendV2;
mod backend;
pub mod errors;
mod looper;
mod utils;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
#[derive(Debug, Clone)]
pub struct GenerationStep {
request_id: u64,
token_id: u32,
log_prob: f32,
is_final: bool,
@ -16,10 +18,6 @@ mod ffi {
error_msg: String,
}
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp");
@ -44,10 +42,7 @@ mod ffi {
fn CreateTensorRtLlmBackend(
engine_folder: &str,
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@ -56,23 +51,18 @@ mod ffi {
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: u32,
top_k: i32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> u64;
) -> Result<u64>;
#[rust_name = "stream_tokens"]
unsafe fn StreamTokens(
#[rust_name = "pull_tokens"]
fn PullTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64,
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
}
}


@ -0,0 +1,395 @@
use std::hint;
use std::ops::Deref;
use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
struct IdentifiableRequest<T> {
request_id: u64,
inner: T,
}
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
request: ValidGenerateRequest,
start: Option<Instant>,
queued: Instant,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
})
} else {
Err(GenerationError(step.error_msg.clone()))
}
}
}
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
struct DecodedTokenContext {
token: DecodedToken,
start: Option<Instant>,
queued: Instant,
channel: UnboundedSender<InferResult<InferStreamResponse>>,
}
fn executor_status_looper(
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
mut waiting_requests: UnboundedReceiver<GenerationContext>,
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
// TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = waiting_requests.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(mut ctx) = waiting_requests.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
let stopping_params = &request.stopping_parameters;
let input_ids = request.input_ids.as_deref();
// Submit to the TensorRT-LLM executor for scheduling
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
generation_params.top_k as i32,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
generation_params.frequency_penalty,
generation_params.seed,
) {
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
// Return to the caller
let what = e.to_string();
error!(error = what.as_str(), "Failed to schedule request");
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
if let Err(_) = ctx.streamer.send(err) {
error!("Failed to send back error to the client");
}
}
};
}
}
if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded token
for step in responses.deref() {
if let Some(ctx) = in_flights.get(&step.request_id) {
// Remove from tracked requests
let parcel =
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
token: dt,
start: ctx.start,
queued: ctx.queued,
channel: ctx.streamer.clone(),
});
// Submit the work to the post_processor
let posted = post_processor_sender.send((step.request_id, parcel));
if posted.is_err() || step.is_final {
debug!("Removing {}", step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
}
}
}
Err(ref err) => {
error!("Failed to get responses from the executor: {}.", err.what());
break 'scheduler;
}
}
}
// Hint the CPU we are spin-locking
hint::spin_loop();
}
}
fn post_processor_looper(
tokenizer: Tokenizer,
max_num_tokens: usize,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
'post_processor: loop {
if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor;
}
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
match decoded {
Ok(ctx) => {
states
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(max_num_tokens);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: ctx.token.id,
text,
logprob: ctx.token.log_prob,
special: is_special,
};
let out = if !ctx.token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
};
if let Err(_) = ctx.channel.send(out) {
warn!("Failed to send decoded token back to the user")
}
}
Err(_err) => {
todo!("what do we do?")
}
}
}
}
}
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
engine_folder: P,
executor_worker_path: PP,
) -> Result<(String, String), TensorRtLlmBackendError> {
// Retrieve paths as &str for the backend creation
let engine_folder = engine_folder.as_ref();
let executor_worker_path = executor_worker_path.as_ref();
// Ensure the engine folder exists
if !engine_folder.exists() {
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
// Ensure executor worker binary exists
if !executor_worker_path.exists() {
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
let engine_folder = String::from(
engine_folder
.to_str()
.expect("Failed to convert engine_folder to valid UTF-8"),
);
let executor_worker_path = String::from(
executor_worker_path
.to_str()
.expect("Failed to convert executor_worker_path to valid UTF-8"),
);
Ok((engine_folder, executor_worker_path))
}
unsafe impl Send for TensorRtLlmBackendImpl {}
pub struct TensorRtLlmBackendV2 {
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
}
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
max_inflight_requests: usize,
) -> Result<Self, TensorRtLlmBackendError> {
let (engine_folder, executor_worker_path) =
ensure_paths_exist(engine_folder, executor_worker_path)?;
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling request states at regular intervals
let executor_looper = spawn_blocking(move || {
executor_status_looper(
backend,
max_inflight_requests,
executor_receiver,
post_processor_sender,
)
});
// Post processor looper is responsible for receiving a bunch of tokens, decoding them and sending them back to the user
let post_processor_looper = spawn_blocking(move || {
post_processor_looper(
tokenizer,
512,
max_inflight_requests,
post_processor_receiver,
)
});
Ok(TensorRtLlmBackendV2 {
executor_looper,
post_processor_looper,
executor: executor_sender,
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
if request.input_ids.is_none() {
return Err(ValidationError(UnsupportedModality("No token provided")));
}
if request.top_n_tokens > 1 {
return Err(ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(ValidationError(EmptyInput)),
2.. => Err(GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
}
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
inner: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Self::validate(&inner)?;
// Open-up the stream to send tokens
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.executor.send(GenerationContext {
request: inner,
start: None,
queued,
streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
"Failed to submit request to the backend".into(),
)),
}
}
async fn health(&self, current_health: bool) -> bool {
current_health
& !self.executor_looper.is_finished()
& !self.post_processor_looper.is_finished()
}
}


@ -1,10 +1,16 @@
use std::path::{Path, PathBuf};
use clap::Parser;
use std::collections::HashMap;
use std::path::PathBuf;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
/// App Configuration
#[derive(Parser, Debug)]
@ -48,14 +54,138 @@ struct Args {
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
async fn get_tokenizer(
tokenizer_name: &str,
tokenizer_config_path: Option<&str>,
revision: Option<&str>,
) -> Option<Tokenizer> {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
let local_path = Path::new(tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
_config_filename,
tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main").to_string(),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}
#[tokio::main]
@ -83,10 +213,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token,
executor_worker,
usage_stats,
} = args;
// Launch Tokio runtime
@ -124,18 +254,26 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
)));
}
// Run server
let tokenizer = Tokenizer::from_pretrained(
tokenizer_name.clone(),
Some(FromPretrainedParameters {
revision: revision.clone().unwrap_or(String::from("main")),
user_agent: HashMap::new(),
auth_token,
}),
// Create the backend
let tokenizer = get_tokenizer(
&tokenizer_name,
tokenizer_config_path.as_deref(),
revision.as_deref(),
)
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
.await
.expect("Failed to retrieve tokenizer implementation");
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
info!("Successfully retrieved tokenizer {}", &tokenizer_name);
let backend = TensorRtLlmBackendV2::new(
tokenizer,
model_id,
executor_worker,
max_concurrent_requests,
)?;
info!("Successfully created backend");
// Run server
server::run(
backend,
max_concurrent_requests,
@ -145,7 +283,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_input_tokens,
max_total_tokens,
validation_workers,
None,
auth_token,
tokenizer_name,
tokenizer_config_path,
revision,
@ -155,11 +293,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
false,
None,
None,
messages_api_enabled,
true,
max_client_batch_size,
false,
false,
usage_stats,
)
.await?;
Ok(())


@ -0,0 +1,22 @@
///
/// Extract the first line of the provided string reference.
/// If there are no lines in the buffer, it returns a string
/// whose content is defined by the content of `fail`.
/// # Arguments
///
/// * `s`: The string buffer to extract the first-line from
/// * `fail`: A string content which is returned if no lines are
/// present in `s`
///
/// returns: String
///
/// # Examples
///
/// ```
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
/// first_line(s, "No line in string");
/// ```
#[inline]
pub(crate) fn first_line(s: &str, fail: &str) -> String {
s.lines().next().unwrap_or(fail).to_string()
}


@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,


@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
@ -200,13 +200,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,


@ -316,6 +316,98 @@
}
}
},
"/invocations": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens from Sagemaker request",
"operationId": "sagemaker_compatibility",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/SagemakerStreamResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error",
"error_type": "validation"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation",
"error_type": "generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded",
"error_type": "overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation",
"error_type": "incomplete_generation"
}
}
}
}
}
}
},
"/metrics": {
"get": {
"tags": [
@ -1865,6 +1957,45 @@
"type": "string"
}
},
"SagemakerRequest": {
"oneOf": [
{
"$ref": "#/components/schemas/CompatGenerateRequest"
},
{
"$ref": "#/components/schemas/ChatRequest"
},
{
"$ref": "#/components/schemas/CompletionRequest"
}
]
},
"SagemakerResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/GenerateResponse"
},
{
"$ref": "#/components/schemas/ChatCompletion"
},
{
"$ref": "#/components/schemas/CompletionFinal"
}
]
},
"SagemakerStreamResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/StreamResponse"
},
{
"$ref": "#/components/schemas/ChatCompletionChunk"
},
{
"$ref": "#/components/schemas/Chunk"
}
]
},
"SimpleToken": {
"type": "object",
"required": [


@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
## Amazon SageMaker
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
This will modify the `/invocations` route to accept Messages dictionaries consisting of role and content. See the example below on how to deploy Llama with the new Messages API.
Amazon Sagemaker natively supports the message API:
```python
import json
@ -161,12 +159,11 @@ except ValueError:
hub = {
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
'SM_NUM_GPUS': json.dumps(1),
'MESSAGES_API_ENABLED': True
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
env=hub,
role=role,
)
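# Illustrative sketch only: deploy the model defined above and call the
# Messages API on the /invocations route. The instance type, timeout, and
# payload fields below are assumptions, not values taken from this repository.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
)

response = predictor.predict({
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"},
    ],
    "max_tokens": 64,
})
print(response)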


@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)


@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
"max_top_n_tokens": 5,
"max_total_tokens": 2048,
"max_waiting_tokens": 20,
"messages_api_enabled": false,
"model_config": {
"model_type": "Bloom"
},


@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1728381423,
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
"lastModified": 1729531056,
"narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
"rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "marlin-kernels-0.3.0",
"repo": "text-generation-inference-nix",
"type": "github"
}


@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
@ -137,6 +137,11 @@
impure = callPackage ./nix/impure-shell.nix { inherit server; };
impureWithCuda = callPackage ./nix/impure-shell.nix {
inherit server;
withCuda = true;
};
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
};


@ -11,27 +11,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -39,66 +39,66 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"logprob": -0.028808594,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"logprob": -0.013671875,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"logprob": -0.69921875,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"logprob": -0.0005874634,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"logprob": -0.026855469,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"logprob": -0.00020885468,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"logprob": -0.17773438,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}


@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
@ -11,22 +11,22 @@
},
{
"id": 374,
"logprob": -22.96875,
"logprob": -18.0,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"logprob": -11.75,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"logprob": -2.0625,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"logprob": -6.0,
"text": "?"
}
],
@ -34,24 +34,66 @@
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"logprob": 0.0,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"id": 34564,
"logprob": -0.11279297,
"special": false,
"text": " "
"text": "Deep"
},
{
"id": 128001,
"id": 6975,
"logprob": -0.16015625,
"special": false,
"text": " learning"
},
{
"id": 320,
"logprob": -0.25195312,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"special": false,
"text": "DL"
},
{
"id": 8,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
"special": false,
"text": ")"
},
{
"id": 374,
"logprob": -1.140625,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1207,
"logprob": -1.3125,
"special": false,
"text": " sub"
},
{
"id": 2630,
"logprob": 0.0,
"special": false,
"text": "field"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
"generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
}


@ -12,27 +12,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -40,68 +40,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.0047912598,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.025512695,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.012145996,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.72265625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0005760193,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02722168,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00023651123,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.17285156,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -116,27 +116,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -144,68 +144,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -220,27 +220,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -248,68 +248,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -324,27 +324,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -352,67 +352,67 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}
]


@ -11,32 +11,32 @@
},
{
"id": 338,
"logprob": -0.7133789,
"logprob": -0.6201172,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"logprob": -13.6484375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"logprob": -0.003894806,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6386719,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"logprob": -6.46875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -44,66 +44,66 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"logprob": -0.00024354458,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
}


@ -5,95 +5,95 @@
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"id": 338,
"logprob": null,
"text": "is"
},
{
"id": 16030,
"logprob": -13.328125,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"logprob": -0.24023438,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"logprob": -3.1386719,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"logprob": -3.0878906,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"id": 25584,
"logprob": 0.0,
"special": false,
"text": "!"
"text": "Grad"
},
{
"id": 739,
"id": 993,
"logprob": 0.0,
"special": false,
"text": " It"
"text": "ient"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"id": 2726,
"logprob": 0.0,
"special": false,
"text": " you"
"text": " Des"
},
{
"id": 29915,
"id": 1760,
"logprob": 0.0,
"special": false,
"text": "'"
"text": "cent"
},
{
"id": 276,
"logprob": -0.9838867,
"id": 313,
"logprob": -0.12322998,
"special": false,
"text": "re"
"text": " ("
},
{
"id": 3211,
"id": 29954,
"logprob": 0.0,
"special": false,
"text": " address"
"text": "G"
},
{
"id": 292,
"id": 29928,
"logprob": 0.0,
"special": false,
"text": "ing"
"text": "D"
},
{
"id": 263,
"logprob": -0.15124512,
"id": 29897,
"logprob": 0.0,
"special": false,
"text": " a"
"text": ")"
},
{
"id": 338,
"logprob": -0.6040039,
"special": false,
"text": " is"
},
{
"id": 385,
"logprob": -0.1796875,
"special": false,
"text": " an"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
"generated_text": "What is gradient descent?\nGradient Descent (GD) is an"
}


@ -12,32 +12,32 @@
},
{
"id": 338,
"logprob": -0.7133789,
"logprob": -0.6201172,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"logprob": -13.6484375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"logprob": -0.003894806,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6386719,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"logprob": -6.46875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -45,68 +45,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"logprob": -0.00097084045,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.23840332,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -121,32 +121,32 @@
},
{
"id": 338,
"logprob": -0.7128906,
"logprob": -0.6113281,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.6640625,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"logprob": -0.003929138,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"logprob": -2.625,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.484375,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -154,68 +154,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"logprob": -0.009017944,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"logprob": -9.536743e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.00097084045,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -230,32 +230,32 @@
},
{
"id": 338,
"logprob": -0.71484375,
"logprob": -0.609375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.671875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"logprob": -0.0040016174,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6230469,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.453125,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"logprob": -6.6875,
"text": "\n"
}
],
@ -263,68 +263,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"logprob": -0.008956909,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"logprob": -0.0003721714,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.48608398,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092601776,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19177246,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -339,32 +339,32 @@
},
{
"id": 338,
"logprob": -0.7192383,
"logprob": -0.609375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.6640625,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"logprob": -0.0038967133,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6347656,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.453125,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"logprob": -6.6875,
"text": "\n"
}
],
@ -372,67 +372,67 @@
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"logprob": -9.536743e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.00038409233,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.48608398,
"logprob": -0.010414124,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"logprob": -0.00024354458,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
}
]


@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
num_shard=2,
kv_cache_dtype="fp8_e4m3fn",
) as handle:
yield handle
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
assert (
response.generated_text
== " Deep learning is a subset of machine learning that is"
== " Deep learning is a subset of machine learning that involves"
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
assert len(responses) == 4
assert (
responses[0].generated_text
== " Deep learning is a subset of machine learning that is"
== " Deep learning is a subset of machine learning that involves"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]


@ -25,7 +25,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
== "Gradient descent is an optimization algorithm commonly used in"
)
assert response == response_snapshot
@ -33,7 +33,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
"What is gradient descent?\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
@ -51,7 +51,7 @@ async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
== "What is gradient descent?\nGradient Descent (GD) is an"
)
assert response == response_snapshot
@ -66,7 +66,7 @@ async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_sna
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
== "Gradient descent is an optimization algorithm commonly used in"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]


@ -1108,6 +1108,8 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
}
}
}
} else {
break;
}
}
}
@ -1519,6 +1521,10 @@ fn spawn_webserver(
router_args.push(revision.to_string())
}
if args.trust_remote_code {
router_args.push("--trust-remote-code".to_string());
}
if args.json_output {
router_args.push("--json-output".to_string());
}


@ -1,7 +1,12 @@
{
lib,
mkShell,
black,
cmake,
isort,
ninja,
which,
cudaPackages,
openssl,
pkg-config,
protobuf,
@ -11,14 +16,17 @@
ruff,
rust-bin,
server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}:
mkShell {
buildInputs =
nativeBuildInputs =
[
black
isort
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
@ -31,6 +39,19 @@ mkShell {
redocly
ruff
]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [
venvShellHook
docker
@ -40,10 +61,29 @@ mkShell {
pytest
pytest-asyncio
syrupy
]);
])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv";
postVenvCreation = ''
@ -51,6 +91,7 @@ mkShell {
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin


@ -150,6 +150,7 @@ pub enum Config {
Idefics2(Idefics2),
Ssm,
GptBigcode,
Granite,
Santacoder,
Bloom,
Mpt,


@ -8,6 +8,7 @@ pub mod validation;
mod kserve;
pub mod logging;
mod sagemaker;
pub mod usage_stats;
mod vertex;


@ -1,748 +0,0 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t)]
disable_usage_stats: bool,
#[clap(long, env, default_value_t)]
disable_crash_reports: bool,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
command,
} = args;
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
let preprocessor_config =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = if !disable_usage_stats && is_container {
let reduced_args = usage_stats::Args::new(
config.clone(),
tokenizer_class,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
);
Some(usage_stats::UserAgent::new(reduced_args))
} else {
None
};
if let Some(ref ua) = user_agent {
let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move {
start_event.send().await;
});
};
// Run server
let result = server::run(
master_shard_uds_path,
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
tokenizer,
config,
validation_workers,
addr,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
preprocessor_config,
processor_config,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await;
match result {
Ok(_) => {
if let Some(ref ua) = user_agent {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
};
Ok(())
}
Err(e) => {
if let Some(ref ua) = user_agent {
if !disable_crash_reports {
let error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some(e.to_string()),
);
error_event.send().await;
} else {
let unknow_error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some("unknow_error".to_string()),
);
unknow_error_event.send().await;
}
};
Err(RouterError::WebServer(e))
}
}
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
/// get model info from the Huggingface Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `User`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
use super::*;
use text_generation_router::TokenizerConfigToken;
#[test]
fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig {
add_bos_token: None,
add_eos_token: None,
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
chat_template: None,
tokenizer_class: None,
completion_template: None,
};
let tokenizer =
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0")
.unwrap()
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
.unwrap()
.special_tokens(vec![("<s>".to_string(), 1)])
.build()
.unwrap();
assert_eq!(post_processor, expected);
}
}

router/src/sagemaker.rs Normal file

@ -0,0 +1,82 @@
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
Generate(CompatGenerateRequest),
Chat(ChatRequest),
Completion(CompletionRequest),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
Generate(GenerateResponse),
Chat(ChatCompletion),
Completion(CompletionFinal),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
Generate(StreamResponse),
Chat(ChatCompletionChunk),
Completion(Chunk),
}
/// Generate tokens from Sagemaker request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/invocations",
request_body = SagemakerRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = SagemakerResponse),
("text/event-stream" = SagemakerStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
)
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
}
}
}
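Because `SagemakerRequest` is an untagged serde enum, the JSON shape of the request body decides which handler runs. A minimal local smoke test could look like the sketch below; the host, port, and payload values are assumptions (3000 is only the router's default port):

```python
# Hypothetical smoke test for the /invocations route: the same endpoint accepts
# generate-, chat-, and completion-shaped payloads.
import requests

base = "http://localhost:3000"  # assumed local router address

# Generate-shaped payload is dispatched to compat_generate.
resp = requests.post(
    f"{base}/invocations",
    json={"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 10}},
)
print(resp.json())

# Chat-shaped payload is dispatched to chat_completions.
resp = requests.post(
    f"{base}/invocations",
    json={"messages": [{"role": "user", "content": "What is deep learning?"}], "max_tokens": 10},
)
print(resp.json())
```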


@ -7,6 +7,10 @@ use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::sagemaker::{
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
__path_sagemaker_compatibility,
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
pub(crate) async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
@ -678,7 +682,7 @@ time_per_token,
seed,
)
)]
async fn completions(
pub(crate) async fn completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
@ -1202,7 +1206,7 @@ time_per_token,
seed,
)
)]
async fn chat_completions(
pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
@ -1513,11 +1517,13 @@ completions,
tokenize,
metrics,
openai_get_model_info,
sagemaker_compatibility,
),
components(
schemas(
Info,
CompatGenerateRequest,
SagemakerRequest,
GenerateRequest,
GrammarType,
ChatRequest,
@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
SagemakerResponse,
SagemakerStreamResponse,
Chunk,
Completion,
CompletionFinal,
@ -1601,13 +1609,13 @@ pub async fn run(
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
trust_remote_code: bool,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
@ -1761,10 +1769,13 @@ pub async fn run(
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
"revision",
revision.clone().unwrap_or_else(|| "main".to_string()),
)]
let kwargs = [
(
"revision",
(revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
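
For readers less familiar with pyo3: the kwargs change above simply forwards `trust_remote_code` (alongside `revision`) into the Hugging Face tokenizer download that the router performs through the embedded Python interpreter. A rough Python equivalent, hedged ("gpt2" and the output directory are placeholders, not values from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", revision="main", trust_remote_code=True)
tokenizer.save_pretrained("/tmp/tokenizer-out")  # placeholder output directory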
@ -1836,7 +1847,6 @@ pub async fn run(
// max_batch_size,
revision.clone(),
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats_level,
@ -1878,7 +1888,6 @@ pub async fn run(
ngrok,
_ngrok_authtoken,
_ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
model_info,
@ -1938,7 +1947,6 @@ async fn start(
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
model_info: HubModelInfo,
@ -2253,6 +2261,7 @@ async fn start(
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/invocations", post(sagemaker_compatibility))
.route("/tokenize", post(tokenize));
if let Some(api_key) = api_key {
@ -2288,13 +2297,6 @@ async fn start(
.route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@ -2302,8 +2304,7 @@ async fn start(
let mut app = Router::new()
.merge(swagger_ui)
.merge(base_routes)
.merge(info_routes)
.merge(aws_sagemaker_route);
.merge(info_routes);
#[cfg(feature = "google")]
{

View File

@ -93,7 +93,6 @@ pub struct Args {
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel,
@ -117,7 +116,6 @@ impl Args {
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel,
@ -138,7 +136,6 @@ impl Args {
// max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats_level,

View File

@ -31,7 +31,7 @@ install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
pip install -e ".[bnb]"
pip install -e ".[bnb,marlin,moe]"
pip install nvidia-nccl-cu12==2.22.3
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

1379
server/poetry.lock generated

File diff suppressed because it is too large

View File

@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26"
marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -28,10 +28,11 @@ else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
from .kv_cache import KVCache, get_kv_scales
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"SUPPORTS_WINDOWING",
"KVCache",

View File

@ -1,5 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@ -8,6 +8,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
@ -21,6 +22,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -46,6 +49,8 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
can_scale = kv_cache.can_scale(kv_scales)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@ -55,10 +60,13 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
elif ATTENTION == "flashdecoding":
max_q = 1
@ -204,6 +212,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
@ -211,6 +220,8 @@ def attention(
causal: bool = True,
softcap: Optional[float] = None,
):
can_scale = kv_cache.can_scale(kv_scales)
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
@ -220,12 +231,15 @@ def attention(
softcap = 0.0
return prefill_with_paged_kv_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(),
causal=causal,
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
# If we are using flashdecoding or paged, we always use flash-attn for

View File

@ -204,6 +204,7 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
kv_cache_dtype: torch.dtype,
dtype: torch.dtype,
window_left: int,
):
@ -240,7 +241,7 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
data_type=kv_cache_dtype,
q_data_type=dtype,
window_left=window_left,
)
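
The new `kv_cache_dtype` argument lets the flashinfer decode plan be built with the dtype the KV pages are actually stored in (for example FP8), while queries keep the model dtype. A minimal sketch of the distinction (values are illustrative; the real call site passes `self.kv_cache_dtype` from FlashCausalLM, as shown further below):

import torch

kv_cache_dtype = torch.float8_e4m3fn  # dtype of the stored KV pages
dtype = torch.bfloat16                # dtype of the incoming queries
# The decode state is then planned with data_type=kv_cache_dtype and
# q_data_type=dtype, mirroring the use_decode_state signature above.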

View File

@ -1,6 +1,6 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
@ -14,6 +14,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
@ -55,6 +56,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
if softcap is not None:

View File

@ -1,8 +1,38 @@
from typing import Tuple
from dataclasses import dataclass, field
from loguru import logger
import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a CPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache:
@ -76,6 +106,33 @@ class KVCache:
),
)
def can_scale(self, kv_scales: KVScales) -> bool:
"""Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif (
self.dtype == torch.float8_e4m3fn
and ATTENTION == "flashinfer"
and SYSTEM == "cuda"
):
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the correct FP8 cache type, so note this once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
)
return False
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
@ -94,17 +151,33 @@ class KVCache:
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
if self.can_scale(kv_scales):
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
scale=kv_scales.key_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if kv_scales.value_scale_cpu != 1.0:
value = fp8_quantize(
value.float(),
scale=kv_scales.value_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
@ -151,5 +224,23 @@ def paged_reshape_and_cache(
)
else:
raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention"
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)
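
To make the new plumbing concrete, a minimal sketch of constructing `KVScales` directly (assumes a working text-generation-server Python environment; in the server the scales come from `get_kv_scales` and the tensors live on the model's device):

import torch
from text_generation_server.layers.attention.kv_cache import KVScales

scales = KVScales(
    key_scale=torch.tensor(0.5),
    value_scale=torch.tensor(0.25),
)
# __post_init__ materializes the CPU copies once, so flashinfer-style callers
# can read plain floats without a device sync on every forward pass.
print(scales.key_scale_cpu, scales.value_scale_cpu)  # 0.5 0.25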

View File

@ -1,7 +1,7 @@
import os
from typing import Optional
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
@ -36,6 +36,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -210,6 +212,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,

View File

@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
return False
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
@ -94,6 +100,17 @@ def fp8_quantize(
)
return qweight, scale
if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
)
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
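
As a reference point, per-tensor FP8 quantization with a dynamic scale boils down to a few lines; the sketch below is a hedged approximation of that math (it is not the marlin_kernels or fbgemm code path above, and the real torch fallback also handles precomputed scales and scale upper bounds that are elided from this hunk):

import torch

def fp8_quantize_sketch(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # Scale the tensor into the FP8 format's representable range, then cast.
    finfo = torch.finfo(qdtype)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qweight, scale.reciprocal().float()  # second value is the dequantization scale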

View File

@ -11,7 +11,7 @@ from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
if SYSTEM == "ipex":
from .ipex import QuantLinear
elif SYSTEM in {"cuda", "rocm"}:
from .cuda import QuantLinear
from .triton import QuantLinear
@dataclass

View File

@ -195,6 +195,11 @@ class ModelType(enum.Enum):
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
@ -862,7 +867,12 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
elif (
model_type == LLAMA
or model_type == BAICHUAN
or model_type == PHI3
or model_type == GRANITE
):
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
@ -876,7 +886,9 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
model_id,

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention,
Seqlen,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
@ -289,7 +291,12 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(

View File

@ -20,6 +20,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex":
@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module):
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -329,6 +337,7 @@ class DeepseekV2Attention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
# Remove padding.

View File

@ -39,6 +39,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module):
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
)
return self.o_proj(

View File

@ -37,6 +37,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -36,6 +36,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size,
num_heads=self.num_heads,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -24,6 +24,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix,
weights=weights,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row(
config,
@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module):
else:
self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module):
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -27,7 +27,10 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import KVCache
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@ -156,7 +159,10 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
# `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@ -176,11 +182,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
bias=getattr(config, "attention_bias", False),
)
self.o_proj = TensorParallelAdapterRowLinear.load(
@ -221,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -230,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module):
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
@ -245,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(
@ -436,6 +451,11 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps,
)
# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)
def forward(
self,
hidden_states,
@ -466,13 +486,16 @@ class FlashLlamaLayer(nn.Module):
max_s,
adapter_data,
)
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.dense(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
return mlp_output, attn_res
@ -624,6 +647,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
suffix = "lm_head"
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
@ -631,6 +659,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights,
)
# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False
def forward(
self,
input_ids: torch.Tensor,
@ -664,4 +702,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
return logits, speculative_logits
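
Taken together, the Granite support above introduces four config-driven scalings into the Llama code path. A hedged sketch of where each one acts (the numbers are made up; the real values come from the Granite model config):

import torch

embedding_multiplier = 12.0       # multiplied into embed_tokens.weight once at load time
attention_multiplier = 0.0078125  # used as softmax_scale instead of head_size**-0.5
residual_multiplier = 0.22        # scales the attention and MLP branches before the residual add
logits_scaling = 16.0             # logits are divided by this (baked into lm_head when possible)

hidden = torch.randn(2, 4096) * embedding_multiplier
scores = torch.randn(2, 8, 16, 16) * attention_multiplier     # attention logits before softmax
hidden = hidden + torch.randn(2, 4096) * residual_multiplier  # attention residual branch
hidden = hidden + torch.randn(2, 4096) * residual_multiplier  # MLP residual branch
logits = torch.randn(2, 49152) / logits_scaling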

View File

@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module):
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(

View File

@ -38,6 +38,7 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
key=qkv[:, 1],
value=qkv[:, 2],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -18,6 +18,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module):
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -16,6 +16,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -12,6 +12,7 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module):
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module):
weights=weights,
bias=config.bias,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
kv_cache.store(
key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
key=kv[:, :, 0].contiguous(),
value=kv[:, :, 1].contiguous(),
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module):
key=kv[:, :, 0],
value=kv[:, :, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.dense(

View File

@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device
)
@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
kv_cache.store(
key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module):
key=key_value[:, 0],
value=key_value[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
)
self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module):
else:
kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module):
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -2313,6 +2313,7 @@ class FlashCausalLM(Model):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
page_size=BLOCK_SIZE,
kv_cache_dtype=self.kv_cache_dtype,
dtype=self.dtype,
window_left=self.sliding_window,
)

View File

@ -207,7 +207,9 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
def get_tensor(
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)

View File

@ -172,6 +172,8 @@ def check_openapi(check: bool):
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"lint",
"--skip-rule",
"security-defined",
filename,
],
capture_output=True,
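
The two extra argv entries ask the OpenAPI linter to skip its security-defined rule. Assuming the linter invoked here is the Redocly CLI and the spec lives at docs/openapi.json (both are elided from this hunk, so they are assumptions), a standalone equivalent would look roughly like:

import subprocess

# Hedged sketch: the binary name and spec path are assumptions; only the
# "--skip-rule security-defined" part comes from the diff above.
subprocess.run(
    ["redocly", "lint", "--skip-rule", "security-defined", "docs/openapi.json"],
    capture_output=True,
)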