Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00).
Commit dd2dcb4d66: merge fab395b41f into 3752143b39.
@@ -119,7 +119,7 @@ struct Args {
     #[clap(default_value = "3000", long, short, env)]
     port: u16,

-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,

     /// Enable JSON output format.
@@ -59,7 +59,14 @@ namespace huggingface::tgi::backends::trtllm {
             static_cast<tle::SizeType32>(g_params.max_new_tokens),
             true,
             (tle::SamplingConfig) s_params,
-            tle::OutputConfig{ /* returnLogProbs= */ true},
+            tle::OutputConfig{
+                    /* returnLogProbs= */ true,
+                    false,
+                    false,
+                    false,
+                    false,
+                    /* returnPerfMetrics=*/ true,
+            },
             std::nullopt,
             std::nullopt,
             std::nullopt,
@@ -1,6 +1,8 @@
 #ifndef TGI_BACKEND_TRTLLM_FFI
 #define TGI_BACKEND_TRTLLM_FFI

+#include <chrono>
+#include <exception>
 #include <memory>
 #include <thread>

@@ -17,7 +19,7 @@ namespace rust::behavior {
     template<typename Try, typename Fail>
     static void trycatch(Try &&func, Fail &&fail) noexcept try {
         func();
-    } catch (tensorrt_llm::common::TllmException &e) {
+    } catch (const std::exception &e) {
         fail(e.what());
     }
 }
@@ -42,22 +44,46 @@ namespace huggingface::tgi::backends::trtllm {
                 return finish_reason_t::kEND_ID;
             case tle::FinishReason::kLENGTH:
                 return finish_reason_t::kLENGTH;
+            case tle::FinishReason::kTIMED_OUT:
+                return finish_reason_t::kTIMED_OUT;
+            case tle::FinishReason::kCANCELLED:
+                return finish_reason_t::kCANCELLED;
             default:
                 std::unreachable();
         }
     }

-    static auto as_generation_step = [](const tle::Response &r) {
+    static auto as_generation_step = [](const tle::Response &r, const std::chrono::time_point<std::chrono::steady_clock> created) {
         const auto reqId = r.getRequestId();
         if (!r.hasError()) [[likely]] {
             const auto result = r.getResult();
-            const auto logits = result.logProbs.value()[0];
+            std::optional<uint32_t> token_id = std::nullopt;
+            if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
+                token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
+            }
+
+            std::optional<float> log_prob = std::nullopt;
+            if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
+                log_prob = result.logProbs.value()[0].back();
+            }
+
+            std::optional<int64_t> first_scheduled_time_ns = std::nullopt;
+            if (result.requestPerfMetrics) {
+                const auto &t = result.requestPerfMetrics->timingMetrics;
+                const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(t.firstScheduledTime - created).count();
+                first_scheduled_time_ns = static_cast<int64_t>(ns);
+            }
+
             return generation_step_t{
                 reqId,
-                static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                logits.back(),
+                token_id.value_or(0),
+                log_prob.value_or(0.0),
+                first_scheduled_time_ns.value_or(0),
                 result.isFinal,
                 as_finish_reason_t(result.finishReasons[0]),
+                token_id.has_value(),
+                log_prob.has_value(),
+                first_scheduled_time_ns.has_value(),
                 false,
                 std::string()
             };
@@ -66,8 +92,12 @@ namespace huggingface::tgi::backends::trtllm {
                 reqId,
                 0,
                 0.0,
+                0,
                 true,
                 finish_reason_t::kNOT_FINISHED,
+                false,
+                false,
+                false,
                 true,
                 std::move(r.getErrorMsg())
             };
@@ -79,9 +109,16 @@ namespace huggingface::tgi::backends::trtllm {
     private:
         backend_t inner_;

+        // m_created_time is a reference point to convert time from c++ time_point
+        // to rust Instant.
+        std::chrono::time_point<std::chrono::steady_clock> m_created_time;
+
+
     public:
-        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
-            : inner_(engine_folder, executor_worker_path) {}
+        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point<std::chrono::steady_clock>& created_time)
+            : inner_(engine_folder, executor_worker_path),
+              m_created_time {created_time}
+        {}

         size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }

@@ -121,13 +158,16 @@ namespace huggingface::tgi::backends::trtllm {

                 SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());

+                auto f = [this](const tle::Response &r){
+                    return as_generation_step(r, m_created_time);
+                };
                 // Transform tle::Response to generation_step_t
 #ifdef __cpp_lib_ranges_to_container
-                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+                auto steps = responses | std::views::transform(f) | std::ranges::to<std::vector>();
 #else
                 auto steps = std::vector<generation_step_t>();
                 steps.reserve(responses.size());
-                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f);
 #endif
                 return std::make_unique<std::vector<generation_step_t>>(steps);

@@ -179,12 +219,14 @@ namespace huggingface::tgi::backends::trtllm {

     std::unique_ptr<tensorrt_llm_backend_t>
     create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+        const auto created_time = std::chrono::steady_clock::now();
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
         return std::make_unique<tensorrt_llm_backend_t>(
                 std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
                                       std::filesystem::path::format::auto_format),
                 std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
-                                      std::filesystem::path::format::auto_format)
+                                      std::filesystem::path::format::auto_format),
+                created_time
         );
     }
 }
@@ -19,4 +19,8 @@ pub enum TensorRtLlmBackendError {
     WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
+    #[error("config.json doesn't exist in engine folder {0}")]
+    ConfigNotFound(PathBuf),
+    #[error("generation_config.json doesn't exist in engine folder {0}")]
+    GenerationConfigNotFound(PathBuf),
 }
@@ -24,6 +24,14 @@ mod ffi {
         /// The request finished because the maximum number of tokens was reached.
         #[cxx_name = "kLENGTH"]
         MaxLength = 3u8,
+
+        #[cxx_name = "kTIMED_OUT"]
+        /// The request finished because it got timed out (via the mAllotedTime parameter)
+        TimedOut = 4u8,
+
+        #[cxx_name = "kCANCELLED"]
+        /// The request was cancelled by calling cancelRequest.
+        Cancelled = 5u8,
     }

     /// Struct used as shared type between rust and C++ to represent the result
@@ -34,8 +42,14 @@ mod ffi {
         request_id: u64,
         token_id: u32,
         log_prob: f32,
+
+        /// The time of first schedule since the creation of the backend
+        first_scheduled_time_ns: i64,
         is_final: bool,
         finish_reason: FinishReason,
+        token_id_valid: bool,
+        log_prob_valid: bool,
+        first_scheduled_time_ns_valid: bool,
         has_error: bool,
         error_msg: String,
     }
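
A note on the paired `*_valid` flags above: cxx shared structs cannot carry `Option<T>` across the FFI boundary, so each optional field travels as a plain value plus a validity bool. A minimal sketch of folding such a pair back into an `Option` on the Rust side (the struct and helper here are illustrative stand-ins, not the backend's actual types):

```rust
// Illustrative stand-in for the shared GenerationStep: cxx cannot express
// Option<T> in a shared struct, so the value and its validity travel apart.
struct Step {
    token_id: u32,
    token_id_valid: bool,
}

// Fold the (value, flag) pair back into an Option on the Rust side.
fn token_id_of(step: &Step) -> Option<u32> {
    step.token_id_valid.then_some(step.token_id)
}

fn main() {
    let step = Step { token_id: 42, token_id_valid: true };
    assert_eq!(token_id_of(&step), Some(42));
    let missing = Step { token_id: 0, token_id_valid: false };
    assert_eq!(token_id_of(&missing), None);
}
```
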
@@ -3,12 +3,12 @@ use cxx::UniquePtr;
 use hashbrown::HashMap;
 use std::hint;
 use std::ops::Deref;
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
 use tokio::task::spawn_blocking;
-use tokio::time::Instant;
+use tokio::time::{Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, warn};

@@ -35,6 +35,9 @@ struct GenerationContext {
     tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
+
+    /// output_buffer stores the output for detecting stop sequences
+    output_buffer: Option<String>,
 }

 #[derive(Debug, Copy, Clone)]
@@ -49,16 +52,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     type Error = InferError;

     fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
-        if !step.has_error {
-            Ok(Self {
-                id: step.token_id,
-                log_prob: step.log_prob,
-                is_final: step.is_final,
-                finish_reason: step.finish_reason,
-            })
-        } else {
-            Err(GenerationError(step.error_msg.clone()))
-        }
+        if step.has_error {
+            return Err(GenerationError(step.error_msg.clone()));
+        }
+
+        if !step.token_id_valid {
+            return Err(GenerationError(
+                "GenerationStep contains no token_id".to_string(),
+            ));
+        }
+
+        if !step.log_prob_valid {
+            return Err(GenerationError(
+                "GenerationStep contains no log_prob".to_string(),
+            ));
+        }
+
+        Ok(Self {
+            id: step.token_id,
+            log_prob: step.log_prob,
+            is_final: step.is_final,
+            finish_reason: step.finish_reason,
+        })
     }
 }

@@ -67,6 +82,7 @@ fn executor_status_looper(
     tokenizer: Tokenizer,
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut backlog: UnboundedReceiver<GenerationContext>,
+    created_time: Instant,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights =
@@ -74,7 +90,12 @@ fn executor_status_looper(

     'scheduler: loop {
         // Is there any request pending to be scheduled?
-        let awaiting_requests = backlog.len();
+        let mut awaiting_requests = backlog.len();
+        if awaiting_requests == 0 && in_flights.is_empty() {
+            // Wait for 1 request if we are not waiting for any response,
+            // so that the loop blocks at receive from backlog.
+            awaiting_requests += 1;
+        }
         for _ in 0..awaiting_requests {
             // Retrieve all the requests
             if let Some(ctx) = backlog.blocking_recv() {
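
The `awaiting_requests` bump above deserves a note: without it, an idle looper would spin with an empty backlog; forcing one `blocking_recv` parks the thread until work arrives. A toy distillation of the rule, with illustrative names:

```rust
// How many blocking receives the scheduler loop should perform this turn.
fn recv_budget(queued: usize, in_flight: usize) -> usize {
    if queued == 0 && in_flight == 0 {
        1 // idle: block on the next recv instead of spinning
    } else {
        queued // otherwise drain exactly what is already queued
    }
}

fn main() {
    assert_eq!(recv_budget(0, 0), 1); // idle loop parks on blocking_recv
    assert_eq!(recv_budget(0, 3), 0); // responses pending: keep polling them
    assert_eq!(recv_budget(5, 2), 5); // drain the whole backlog
}
```
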
@@ -83,12 +104,17 @@ fn executor_status_looper(
                 let generation_params = &request.parameters;
                 let stopping_params = &request.stopping_parameters;
                 let input_ids = request.input_ids.as_deref();
+                let top_k = if generation_params.do_sample {
+                    generation_params.top_k
+                } else {
+                    1
+                };

                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
                     &input_ids.unwrap(), // This is checked beforehand in validate()
                     stopping_params.max_new_tokens,
-                    generation_params.top_k,
+                    top_k,
                     generation_params.top_p,
                     generation_params.temperature,
                     generation_params.repetition_penalty,
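
The new `top_k` binding encodes greedy decoding: when `do_sample` is false, the sampling config is forced to `top_k = 1` so the executor always picks the argmax token regardless of the requested value. A one-function sketch of that guard (standalone, illustrative names):

```rust
// Greedy decoding is expressed to the executor as top_k = 1.
fn effective_top_k(do_sample: bool, requested_top_k: u32) -> u32 {
    if do_sample {
        requested_top_k
    } else {
        1 // ignore the requested value when sampling is disabled
    }
}

fn main() {
    assert_eq!(effective_top_k(false, 50), 1);
    assert_eq!(effective_top_k(true, 50), 50);
}
```
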
@@ -124,12 +150,22 @@ fn executor_status_looper(
                 for step in responses.deref() {
                     if let Some(ctx) = in_flights.get_mut(&step.request_id) {
                         // Update the starting timestamp if not set
-                        // This value might not be the actual real starting time of the request
-                        // on the executor side - Need to expose more info from the executor to
-                        // retrieve this value
-                        // TODO : Expose actual real starting time for a request on FFI layer
                         if ctx.start.is_none() {
-                            ctx.start = Some(Instant::now());
+                            if step.first_scheduled_time_ns_valid {
+                                if step.first_scheduled_time_ns >= 0 {
+                                    ctx.start = created_time.checked_add(Duration::from_nanos(
+                                        step.first_scheduled_time_ns as u64,
+                                    ));
+                                } else {
+                                    ctx.start = created_time.checked_sub(Duration::from_nanos(
+                                        -step.first_scheduled_time_ns as u64,
+                                    ));
+                                }
+                            }
+
+                            if ctx.start.is_none() {
+                                ctx.start = Some(Instant::now());
+                            }
                         }

                         // Try to map the generation step to a DecodedToken
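
The branch above bridges two clocks: the C++ side reports `first_scheduled_time_ns` as a signed nanosecond offset from the shared creation reference, and the Rust side rebuilds an `Instant` from it, falling back to `Instant::now()` when the offset is absent or out of range. A standalone sketch of the conversion, with std `Instant` standing in for `tokio::time::Instant`:

```rust
use std::time::{Duration, Instant};

// Rebuild an Instant from a signed ns offset relative to `created_time`.
// A negative offset means the event happened before the reference point.
fn instant_from_offset_ns(created_time: Instant, offset_ns: i64) -> Option<Instant> {
    if offset_ns >= 0 {
        created_time.checked_add(Duration::from_nanos(offset_ns as u64))
    } else {
        created_time.checked_sub(Duration::from_nanos(offset_ns.unsigned_abs()))
    }
}

fn main() {
    let created = Instant::now();
    // An event 1.5 ms after the backend was created.
    assert!(instant_from_offset_ns(created, 1_500_000).is_some());
}
```
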
@@ -151,7 +187,16 @@ fn executor_status_looper(
                             let _ = in_flights.remove(&step.request_id);
                         }
                     } else {
-                        warn!("Untracked request {}", step.request_id,);
+                        match step.finish_reason {
+                            FinishReason::Cancelled => {
+                                // The client has canceled the request, so this should not generate a
+                                // warning.
+                                debug!("Cancelled request {}", step.request_id);
+                            }
+                            _ => {
+                                warn!("Untracked request {}", step.request_id);
+                            }
+                        }
                     }
                 }
             }
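
The `match` above downgrades one specific case: a response for an unknown request id is expected when the client itself cancelled, so it logs at debug instead of warn. A compact model of that decision (the enum and function here are illustrative):

```rust
#[derive(Debug)]
enum FinishReason {
    Cancelled,
    Other,
}

// Pick a log level for a response whose request id is no longer tracked.
fn level_for_untracked(reason: &FinishReason) -> &'static str {
    match reason {
        // Client-initiated cancellation is routine, not a warning.
        FinishReason::Cancelled => "debug",
        _ => "warn",
    }
}

fn main() {
    assert_eq!(level_for_untracked(&FinishReason::Cancelled), "debug");
    assert_eq!(level_for_untracked(&FinishReason::Other), "warn");
}
```
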
@@ -170,11 +215,39 @@ fn executor_status_looper(
 fn post_process_decoded_token(
     tokenizer: &Tokenizer,
     ctx: &mut GenerationContext,
-    decoded_token: DecodedToken,
+    mut decoded_token: DecodedToken,
 ) -> InferResult<InferStreamResponse> {
     match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
             let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+
+            if let Some(buf) = ctx.output_buffer.as_mut() {
+                if buf.len() + text.len() > buf.capacity() {
+                    let mut start = buf.len() + text.len() - buf.capacity();
+                    while start <= buf.len() && !buf.is_char_boundary(start) {
+                        start += 1;
+                    }
+                    buf.drain(..start);
+                }
+                buf.push_str(&text);
+
+                for stop_seq in &ctx.request.stopping_parameters.stop_sequences {
+                    let start = if 1 + buf.len() > text.len() + stop_seq.len() {
+                        let mut start = 1 + buf.len() - text.len() - stop_seq.len();
+                        while start > 0 && !buf.is_char_boundary(start) {
+                            start -= 1;
+                        }
+                        start
+                    } else {
+                        0
+                    };
+                    if buf[start..].contains(stop_seq) {
+                        decoded_token.is_final = true;
+                        decoded_token.finish_reason = FinishReason::StopWords;
+                    }
+                }
+            }
+
             let token = Token {
                 id: decoded_token.id,
                 text,
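
The buffer logic above keeps only a bounded tail of the decoded text and rescans just the window where a stop sequence could newly appear, stepping indices to UTF-8 char boundaries before slicing. A standalone sketch of the same two steps (the function name and the explicit `capacity` parameter are illustrative, not the backend's API):

```rust
// Returns true if `stop_seq` appears once `text` is appended to the buffer.
fn push_and_check(buf: &mut String, text: &str, stop_seq: &str, capacity: usize) -> bool {
    // Keep only the tail of the output, moving the cut forward to the next
    // UTF-8 char boundary so the drain below cannot split a character.
    if buf.len() + text.len() > capacity {
        let mut start = buf.len() + text.len() - capacity;
        while start <= buf.len() && !buf.is_char_boundary(start) {
            start += 1;
        }
        buf.drain(..start.min(buf.len())); // min() guards text longer than capacity
    }
    buf.push_str(text);

    // Only the window that the new text could have completed needs scanning:
    // a match entirely inside the old tail was already reported earlier.
    let start = if 1 + buf.len() > text.len() + stop_seq.len() {
        let mut start = 1 + buf.len() - text.len() - stop_seq.len();
        while start > 0 && !buf.is_char_boundary(start) {
            start -= 1;
        }
        start
    } else {
        0
    };
    buf[start..].contains(stop_seq)
}

fn main() {
    let mut buf = String::new();
    assert!(!push_and_check(&mut buf, "Hello, wor", "END", 16));
    assert!(push_and_check(&mut buf, "ld END", "END", 16));
}
```
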
@@ -231,6 +304,26 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
         return Err(err);
     }

+    let mut config_path = PathBuf::from(engine_folder);
+    config_path.push("config.json");
+
+    if !config_path.exists() {
+        let err = TensorRtLlmBackendError::ConfigNotFound(engine_folder.to_path_buf());
+
+        error!("Path validation failed: {}", err,);
+        return Err(err);
+    }
+
+    let mut generation_config_path = PathBuf::from(engine_folder);
+    generation_config_path.push("generation_config.json");
+
+    if !generation_config_path.exists() {
+        let err = TensorRtLlmBackendError::GenerationConfigNotFound(engine_folder.to_path_buf());
+
+        error!("Path validation failed: {}", err,);
+        return Err(err);
+    }
+
     // Ensure executor worker binary exists
     if !executor_worker_path.exists() {
         let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
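
The two new checks follow the same template and could be factored into a small helper. A refactoring sketch under the assumption that a plain `String` error is acceptable (the real code returns dedicated `TensorRtLlmBackendError` variants):

```rust
use std::path::{Path, PathBuf};

// Require that `name` exists inside the engine folder, mirroring the
// config.json / generation_config.json checks above.
fn require_file(engine_folder: &Path, name: &str) -> Result<PathBuf, String> {
    let mut path = PathBuf::from(engine_folder);
    path.push(name);
    if path.exists() {
        Ok(path)
    } else {
        Err(format!(
            "{name} doesn't exist in engine folder {}",
            engine_folder.display()
        ))
    }
}

fn main() {
    for name in ["config.json", "generation_config.json"] {
        match require_file(Path::new("/tmp/engines"), name) {
            Ok(path) => println!("found {}", path.display()),
            Err(err) => eprintln!("path validation failed: {err}"),
        }
    }
}
```
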
@@ -271,13 +364,23 @@ impl TensorRtLlmBackendV2 {
         // Allocate the IPC layer to communicate with the backend
         let (executor_sender, executor_receiver) = unbounded_channel();

+        // This is a reference point to convert time from c++ time_point
+        // to rust Instant.
+        let created_time = Instant::now();
+
         // Create the FFI backend
         let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         spawn_blocking(move || {
-            executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
+            executor_status_looper(
+                max_inflight_requests,
+                tokenizer,
+                backend,
+                executor_receiver,
+                created_time,
+            )
         });

         Ok(TensorRtLlmBackendV2(executor_sender))
@@ -323,12 +426,20 @@ impl Backend for TensorRtLlmBackendV2 {

         // Send the context to the executor for scheduling
         let queued = Instant::now();
+        let output_buffer = request
+            .stopping_parameters
+            .stop_sequences
+            .iter()
+            .map(|x| x.len())
+            .max()
+            .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough?
         match self.0.send(GenerationContext {
             request,
             streamer,
             tokens: Vec::with_capacity(256),
             start: None,
             queued,
+            output_buffer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
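
The allocation above sizes the scan buffer from the longest stop sequence plus 32 bytes of slack for the newest token text (the slack constant is the diff's own TODO, not a derived bound), and skips allocation entirely when no stop sequences were requested. A standalone sketch of that sizing rule:

```rust
// Allocate the stop-sequence scan buffer only when it can ever match.
fn output_buffer_for(stop_sequences: &[String]) -> Option<String> {
    stop_sequences
        .iter()
        .map(|s| s.len())
        .max()
        .map(|longest| String::with_capacity(longest + 32))
}

fn main() {
    assert!(output_buffer_for(&[]).is_none()); // no stop sequences: no buffer
    let buf = output_buffer_for(&["END".to_string(), "<|eot_id|>".to_string()]);
    assert!(buf.unwrap().capacity() >= 10 + 32);
}
```
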
@@ -37,7 +37,7 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     #[clap(long, env, required = true)]
     tokenizer_name: String,
@@ -36,7 +36,7 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
@@ -36,7 +36,7 @@ struct Args {
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
@@ -254,7 +254,7 @@ Options:
 ```
 ## PROMETHEUS_PORT
 ```shell
--p, --prometheus-port <PROMETHEUS_PORT>
+    --prometheus-port <PROMETHEUS_PORT>
           The Prometheus port to listen on

           [env: PROMETHEUS_PORT=]
@@ -774,7 +774,7 @@ struct Args {
     port: u16,

     /// The Prometheus port to listen on.
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,

     /// The name of the socket for gRPC communication between the webserver
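
Why `short` disappears across all of these `Args` structs: clap's derive takes the short flag from the first letter of the field name, so `port` and `prometheus_port` both claim `-p`, which clap rejects as a duplicate; plausibly that conflict motivated the change, leaving `--prometheus-port` long-only. A minimal reproduction of the fixed shape (assumes clap v4 with the `derive` and `env` features enabled):

```rust
use clap::Parser; // assumes clap = { version = "4", features = ["derive", "env"] }

#[derive(Parser, Debug)]
struct Args {
    /// Keeps its short flag: -p / --port.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// `short` removed: with it, this field would also derive -p and clash
    /// with `port`; now it is reachable only as --prometheus-port.
    #[clap(default_value = "9000", long, env)]
    prometheus_port: u16,
}

fn main() {
    let args = Args::parse();
    println!("port={} prometheus_port={}", args.port, args.prometheus_port);
}
```
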