Tzu-Yu Lee 2025-06-14 22:30:03 +08:00 committed by GitHub
commit dd2dcb4d66
11 changed files with 216 additions and 38 deletions

View File

@@ -119,7 +119,7 @@ struct Args {
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     /// Enable JSON output format.

View File

@@ -59,7 +59,14 @@ namespace huggingface::tgi::backends::trtllm {
                 static_cast<tle::SizeType32>(g_params.max_new_tokens),
                 true,
                 (tle::SamplingConfig) s_params,
-                tle::OutputConfig{ /* returnLogProbs= */ true},
+                tle::OutputConfig{
+                        /* returnLogProbs= */ true,
+                        false,
+                        false,
+                        false,
+                        false,
+                        /* returnPerfMetrics= */ true,
+                },
                 std::nullopt,
                 std::nullopt,
                 std::nullopt,

View File

@@ -1,6 +1,8 @@
 #ifndef TGI_BACKEND_TRTLLM_FFI
 #define TGI_BACKEND_TRTLLM_FFI

+#include <chrono>
+#include <exception>
 #include <memory>
 #include <thread>
@@ -17,7 +19,7 @@ namespace rust::behavior {
     template<typename Try, typename Fail>
     static void trycatch(Try &&func, Fail &&fail) noexcept try {
         func();
-    } catch (tensorrt_llm::common::TllmException &e) {
+    } catch (const std::exception &e) {
         fail(e.what());
     }
 }
@@ -42,22 +44,46 @@ namespace huggingface::tgi::backends::trtllm {
                 return finish_reason_t::kEND_ID;
             case tle::FinishReason::kLENGTH:
                 return finish_reason_t::kLENGTH;
+            case tle::FinishReason::kTIMED_OUT:
+                return finish_reason_t::kTIMED_OUT;
+            case tle::FinishReason::kCANCELLED:
+                return finish_reason_t::kCANCELLED;
             default:
                 std::unreachable();
         }
     }

-    static auto as_generation_step = [](const tle::Response &r) {
+    static auto as_generation_step = [](const tle::Response &r, const std::chrono::time_point<std::chrono::steady_clock> created) {
         const auto reqId = r.getRequestId();
         if (!r.hasError()) [[likely]] {
             const auto result = r.getResult();
-            const auto logits = result.logProbs.value()[0];
+            std::optional<uint32_t> token_id = std::nullopt;
+            if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
+                token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
+            }
+
+            std::optional<float> log_prob = std::nullopt;
+            if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
+                log_prob = result.logProbs.value()[0].back();
+            }
+
+            std::optional<int64_t> first_scheduled_time_ns = std::nullopt;
+            if (result.requestPerfMetrics) {
+                const auto &t = result.requestPerfMetrics->timingMetrics;
+                const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(t.firstScheduledTime - created).count();
+                first_scheduled_time_ns = static_cast<int64_t>(ns);
+            }
+
             return generation_step_t{
                     reqId,
-                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                    logits.back(),
+                    token_id.value_or(0),
+                    log_prob.value_or(0.0),
+                    first_scheduled_time_ns.value_or(0),
                     result.isFinal,
                     as_finish_reason_t(result.finishReasons[0]),
+                    token_id.has_value(),
+                    log_prob.has_value(),
+                    first_scheduled_time_ns.has_value(),
                     false,
                     std::string()
             };
@@ -66,8 +92,12 @@ namespace huggingface::tgi::backends::trtllm {
                     reqId,
                     0,
                     0.0,
+                    0,
                     true,
                     finish_reason_t::kNOT_FINISHED,
+                    false,
+                    false,
+                    false,
                     true,
                     std::move(r.getErrorMsg())
             };
@@ -79,9 +109,16 @@ namespace huggingface::tgi::backends::trtllm {
     private:
         backend_t inner_;

+        // m_created_time is a reference point to convert time from c++ time_point
+        // to rust Instant.
+        std::chrono::time_point<std::chrono::steady_clock> m_created_time;
+
     public:
-        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
-                : inner_(engine_folder, executor_worker_path) {}
+        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point<std::chrono::steady_clock>& created_time)
+                : inner_(engine_folder, executor_worker_path),
+                  m_created_time {created_time}
+        {}

         size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
@@ -121,13 +158,16 @@ namespace huggingface::tgi::backends::trtllm {
                 SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());

+                auto f = [this](const tle::Response &r){
+                    return as_generation_step(r, m_created_time);
+                };
+
                 // Transform tle::Response to generation_step_t
 #ifdef __cpp_lib_ranges_to_container
-                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+                auto steps = responses | std::views::transform(f) | std::ranges::to<std::vector>();
 #else
                 auto steps = std::vector<generation_step_t>();
                 steps.reserve(responses.size());
-                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f);
 #endif

                 return std::make_unique<std::vector<generation_step_t>>(steps);
@@ -179,12 +219,14 @@ namespace huggingface::tgi::backends::trtllm {
     std::unique_ptr<tensorrt_llm_backend_t>
     create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+        const auto created_time = std::chrono::steady_clock::now();
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
         return std::make_unique<tensorrt_llm_backend_t>(
                 std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
                                       std::filesystem::path::format::auto_format),
                 std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
-                                      std::filesystem::path::format::auto_format)
+                                      std::filesystem::path::format::auto_format),
+                created_time
         );
     }
 }

View File

@@ -19,4 +19,8 @@ pub enum TensorRtLlmBackendError {
     WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
+    #[error("config.json doesn't exist in engine folder {0}")]
+    ConfigNotFound(PathBuf),
+    #[error("generation_config.json doesn't exist in engine folder {0}")]
+    GenerationConfigNotFound(PathBuf),
 }

View File

@@ -24,6 +24,14 @@ mod ffi {
         /// The request finished because the maximum number of tokens was reached.
         #[cxx_name = "kLENGTH"]
         MaxLength = 3u8,
+        #[cxx_name = "kTIMED_OUT"]
+        /// The request finished because it got timed out (via the mAllotedTime parameter)
+        TimedOut = 4u8,
+        #[cxx_name = "kCANCELLED"]
+        /// The request was cancelled by calling cancelRequest.
+        Cancelled = 5u8,
     }

     /// Struct used as shared type between rust and C++ to represent the result
@@ -34,8 +42,14 @@ mod ffi {
         request_id: u64,
         token_id: u32,
         log_prob: f32,
+        /// The time of first schedule since the creation of the backend
+        first_scheduled_time_ns: i64,
         is_final: bool,
         finish_reason: FinishReason,
+        token_id_valid: bool,
+        log_prob_valid: bool,
+        first_scheduled_time_ns_valid: bool,
         has_error: bool,
         error_msg: String,
     }
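The paired `*_valid` booleans exist most likely because this struct is shared across the cxx bridge, where optional fields are awkward to express; the flags let the Rust side fold each value/flag pair back into an `Option`. A hypothetical helper (not part of the patch) illustrating that folding:

```rust
/// Fold a value/flag pair coming over the FFI boundary back into an Option.
fn optional_value<T>(value: T, valid: bool) -> Option<T> {
    valid.then_some(value)
}
```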

View File

@@ -3,12 +3,12 @@ use cxx::UniquePtr;
 use hashbrown::HashMap;
 use std::hint;
 use std::ops::Deref;
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
 use tokio::task::spawn_blocking;
-use tokio::time::Instant;
+use tokio::time::{Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, warn};
@@ -35,6 +35,9 @@ struct GenerationContext {
     tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
+    /// output_buffer stores the output for detecting stop sequences
+    output_buffer: Option<String>,
 }

 #[derive(Debug, Copy, Clone)]
@@ -49,16 +52,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     type Error = InferError;

     fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
-        if !step.has_error {
-            Ok(Self {
-                id: step.token_id,
-                log_prob: step.log_prob,
-                is_final: step.is_final,
-                finish_reason: step.finish_reason,
-            })
-        } else {
-            Err(GenerationError(step.error_msg.clone()))
+        if step.has_error {
+            return Err(GenerationError(step.error_msg.clone()));
         }
+
+        if !step.token_id_valid {
+            return Err(GenerationError(
+                "GenerationStep contains no token_id".to_string(),
+            ));
+        }
+
+        if !step.log_prob_valid {
+            return Err(GenerationError(
+                "GenerationStep contains no log_prob".to_string(),
+            ));
+        }
+
+        Ok(Self {
+            id: step.token_id,
+            log_prob: step.log_prob,
+            is_final: step.is_final,
+            finish_reason: step.finish_reason,
+        })
     }
 }
@@ -67,6 +82,7 @@ fn executor_status_looper(
     tokenizer: Tokenizer,
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut backlog: UnboundedReceiver<GenerationContext>,
+    created_time: Instant,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights =
@@ -74,7 +90,12 @@ fn executor_status_looper(
     'scheduler: loop {
         // Is there any request pending to be scheduled?
-        let awaiting_requests = backlog.len();
+        let mut awaiting_requests = backlog.len();
+        if awaiting_requests == 0 && in_flights.is_empty() {
+            // Wait for 1 request if we are not waiting for any response,
+            // so that the loop blocks at receive from backlog.
+            awaiting_requests += 1;
+        }
         for _ in 0..awaiting_requests {
             // Retrieve all the requests
             if let Some(ctx) = backlog.blocking_recv() {
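The added branch changes how the looper idles: when nothing is queued and nothing is in flight, it blocks on the backlog instead of spinning through the polling path. A minimal sketch of that pattern with illustrative names (not the backend's actual helper):

```rust
use tokio::sync::mpsc::UnboundedReceiver;

// Drain whatever is already queued; if nothing is queued and nothing is in
// flight, block for exactly one request so the thread sleeps instead of
// busy-polling the executor.
fn drain_backlog<T>(backlog: &mut UnboundedReceiver<T>, in_flight: usize) -> Vec<T> {
    let mut to_schedule = backlog.len();
    if to_schedule == 0 && in_flight == 0 {
        to_schedule = 1;
    }
    let mut batch = Vec::with_capacity(to_schedule);
    for _ in 0..to_schedule {
        match backlog.blocking_recv() {
            Some(ctx) => batch.push(ctx),
            None => break, // channel closed: no more producers
        }
    }
    batch
}
```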
@@ -83,12 +104,17 @@ fn executor_status_looper(
                 let generation_params = &request.parameters;
                 let stopping_params = &request.stopping_parameters;
                 let input_ids = request.input_ids.as_deref();
+                let top_k = if generation_params.do_sample {
+                    generation_params.top_k
+                } else {
+                    1
+                };

                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
                     &input_ids.unwrap(), // This is checked beforehand in validate()
                     stopping_params.max_new_tokens,
-                    generation_params.top_k,
+                    top_k,
                     generation_params.top_p,
                     generation_params.temperature,
                     generation_params.repetition_penalty,
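Forcing `top_k` to 1 when sampling is disabled makes the executor decode greedily regardless of the `top_k` the client sent. A small sketch of the rule (helper name is illustrative):

```rust
/// When sampling is disabled, restrict the candidate set to the single most
/// likely token, i.e. greedy decoding.
fn effective_top_k(do_sample: bool, requested_top_k: u32) -> u32 {
    if do_sample {
        requested_top_k
    } else {
        1
    }
}

fn main() {
    assert_eq!(effective_top_k(false, 10), 1); // greedy decoding
    assert_eq!(effective_top_k(true, 10), 10); // sampling keeps the requested k
}
```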
@@ -124,12 +150,22 @@ fn executor_status_looper(
                 for step in responses.deref() {
                     if let Some(ctx) = in_flights.get_mut(&step.request_id) {
                         // Update the starting timestamp if not set
-                        // This value might not be the actual real starting time of the request
-                        // on the executor side - Need to expose more info from the executor to
-                        // retrieve this value
-                        // TODO : Expose actual real starting time for a request on FFI layer
                         if ctx.start.is_none() {
-                            ctx.start = Some(Instant::now());
+                            if step.first_scheduled_time_ns_valid {
+                                if step.first_scheduled_time_ns >= 0 {
+                                    ctx.start = created_time.checked_add(Duration::from_nanos(
+                                        step.first_scheduled_time_ns as u64,
+                                    ));
+                                } else {
+                                    ctx.start = created_time.checked_sub(Duration::from_nanos(
+                                        -step.first_scheduled_time_ns as u64,
+                                    ));
+                                }
+                            }
+                            if ctx.start.is_none() {
+                                ctx.start = Some(Instant::now());
+                            }
                         }

                         // Try to map the generation step to a DecodedToken
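This is the Rust half of the new perf-metrics plumbing: the C++ layer reports `firstScheduledTime` as a signed nanosecond offset from the `steady_clock` reference captured at backend creation, and the looper re-anchors that offset on the matching `created_time: Instant`, falling back to `Instant::now()` if the conversion fails. A condensed sketch (helper name is illustrative):

```rust
use tokio::time::{Duration, Instant};

/// Re-anchor a signed nanosecond offset (relative to `created_time`) as an Instant.
/// Returns None if the arithmetic overflows, in which case the caller falls back
/// to Instant::now(), as the loop above does.
fn scheduled_instant(created_time: Instant, offset_ns: i64) -> Option<Instant> {
    if offset_ns >= 0 {
        created_time.checked_add(Duration::from_nanos(offset_ns as u64))
    } else {
        // A negative offset means the event was recorded before the reference point.
        created_time.checked_sub(Duration::from_nanos(offset_ns.unsigned_abs()))
    }
}
```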
@@ -151,7 +187,16 @@ fn executor_status_looper(
                             let _ = in_flights.remove(&step.request_id);
                         }
                     } else {
-                        warn!("Untracked request {}", step.request_id,);
+                        match step.finish_reason {
+                            FinishReason::Cancelled => {
+                                // The client has canceled the request, so this should not generate a
+                                // warning.
+                                debug!("Cancelled request {}", step.request_id);
+                            }
+                            _ => {
+                                warn!("Untracked request {}", step.request_id);
+                            }
+                        }
                     }
                 }
             }
@@ -170,11 +215,39 @@ fn executor_status_looper(
 fn post_process_decoded_token(
     tokenizer: &Tokenizer,
     ctx: &mut GenerationContext,
-    decoded_token: DecodedToken,
+    mut decoded_token: DecodedToken,
 ) -> InferResult<InferStreamResponse> {
     match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
             let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+
+            if let Some(buf) = ctx.output_buffer.as_mut() {
+                if buf.len() + text.len() > buf.capacity() {
+                    let mut start = buf.len() + text.len() - buf.capacity();
+                    while start <= buf.len() && !buf.is_char_boundary(start) {
+                        start += 1;
+                    }
+                    buf.drain(..start);
+                }
+                buf.push_str(&text);
+
+                for stop_seq in &ctx.request.stopping_parameters.stop_sequences {
+                    let start = if 1 + buf.len() > text.len() + stop_seq.len() {
+                        let mut start = 1 + buf.len() - text.len() - stop_seq.len();
+                        while start > 0 && !buf.is_char_boundary(start) {
+                            start -= 1;
+                        }
+                        start
+                    } else {
+                        0
+                    };
+                    if buf[start..].contains(stop_seq) {
+                        decoded_token.is_final = true;
+                        decoded_token.finish_reason = FinishReason::StopWords;
+                    }
+                }
+            }
+
             let token = Token {
                 id: decoded_token.id,
                 text,
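The new block keeps a bounded rolling buffer of decoded text so a stop sequence that straddles token boundaries is still detected: old bytes are evicted at a char boundary, and only the suffix overlapping the newly appended text is rescanned. A self-contained sketch of the same idea, assuming the buffer capacity was pre-sized for the longest stop sequence (see the allocation further down in the Backend implementation):

```rust
/// Append newly decoded text to a bounded rolling buffer and report whether any
/// stop sequence now appears. Eviction happens at a char boundary so the buffer
/// stays valid UTF-8; only the suffix that can contain a fresh match is scanned.
fn push_and_check(buf: &mut String, text: &str, stop_sequences: &[String]) -> bool {
    if buf.len() + text.len() > buf.capacity() {
        let mut cut = (buf.len() + text.len() - buf.capacity()).min(buf.len());
        while cut < buf.len() && !buf.is_char_boundary(cut) {
            cut += 1;
        }
        buf.drain(..cut);
    }
    buf.push_str(text);

    stop_sequences.iter().any(|stop_seq| {
        let mut start = (buf.len() + 1)
            .saturating_sub(text.len() + stop_seq.len())
            .min(buf.len());
        while start > 0 && !buf.is_char_boundary(start) {
            start -= 1;
        }
        buf[start..].contains(stop_seq.as_str())
    })
}
```

Compared with re-scanning the whole generated text each step, this keeps the per-token cost bounded by the length of the longest stop sequence.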
@@ -231,6 +304,26 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
         return Err(err);
     }

+    let mut config_path = PathBuf::from(engine_folder);
+    config_path.push("config.json");
+    if !config_path.exists() {
+        let err = TensorRtLlmBackendError::ConfigNotFound(engine_folder.to_path_buf());
+        error!("Path validation failed: {}", err,);
+        return Err(err);
+    }
+
+    let mut generation_config_path = PathBuf::from(engine_folder);
+    generation_config_path.push("generation_config.json");
+    if !generation_config_path.exists() {
+        let err = TensorRtLlmBackendError::GenerationConfigNotFound(engine_folder.to_path_buf());
+        error!("Path validation failed: {}", err,);
+        return Err(err);
+    }
+
     // Ensure executor worker binary exists
     if !executor_worker_path.exists() {
         let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
@@ -271,13 +364,23 @@ impl TensorRtLlmBackendV2 {
         // Allocate the IPC layer to communicate with the backend
         let (executor_sender, executor_receiver) = unbounded_channel();

+        // This is a reference point to convert time from c++ time_point
+        // to rust Instant.
+        let created_time = Instant::now();
+
         // Create the FFI backend
         let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         spawn_blocking(move || {
-            executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
+            executor_status_looper(
+                max_inflight_requests,
+                tokenizer,
+                backend,
+                executor_receiver,
+                created_time,
+            )
         });

         Ok(TensorRtLlmBackendV2(executor_sender))
@@ -323,12 +426,20 @@ impl Backend for TensorRtLlmBackendV2 {
         // Send the context to the executor for scheduling
         let queued = Instant::now();
+        let output_buffer = request
+            .stopping_parameters
+            .stop_sequences
+            .iter()
+            .map(|x| x.len())
+            .max()
+            .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough?
+
         match self.0.send(GenerationContext {
             request,
             streamer,
             tokens: Vec::with_capacity(256),
             start: None,
             queued,
+            output_buffer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
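On the `TODO: is this number enough?`: after eviction the rolling buffer retains roughly `capacity - text.len()` old bytes, so a match of length `L` ending in freshly appended text survives as long as `capacity >= L + text.len() - 1`. In other words, the capacity has to cover the longest stop sequence plus the most bytes a single decoded token can append; the `+ 32` is this patch's guess for that per-step bound. A sketch of the sizing rule under that assumption:

```rust
/// Capacity for a rolling buffer that can always hold a match for the longest
/// stop sequence, assuming a single decoding step appends at most `max_step_bytes`.
fn buffer_capacity(stop_sequences: &[String], max_step_bytes: usize) -> Option<usize> {
    stop_sequences
        .iter()
        .map(|s| s.len())
        .max()
        .map(|longest| longest + max_step_bytes)
}
```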

View File

@@ -37,7 +37,7 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     #[clap(long, env, required = true)]
     tokenizer_name: String,

View File

@@ -36,7 +36,7 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,

View File

@@ -36,7 +36,7 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,

View File

@@ -254,7 +254,7 @@ Options:
 ```

 ## PROMETHEUS_PORT
 ```shell
-  -p, --prometheus-port <PROMETHEUS_PORT>
+      --prometheus-port <PROMETHEUS_PORT>
           The Prometheus port to listen on

           [env: PROMETHEUS_PORT=]

View File

@@ -774,7 +774,7 @@ struct Args {
     port: u16,

     /// The Prometheus port to listen on.
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,

     /// The name of the socket for gRPC communication between the webserver