impl RwLock scenario for TensorRtLlmBackend

Author: Morgan Funtowicz
Date: 2024-07-16 20:08:10 +00:00
Parent: 31d9f4d5dc
Commit: 7784a21d48
8 changed files with 352 additions and 173 deletions

View File

@@ -15,6 +15,10 @@ tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot"
 tokio-stream = "0.1.14"
 clap = { version = "4.5.4", features = ["derive"] }
 thiserror = "1.0.61"
+tracing = "0.1"
+tracing-opentelemetry = "0.24"
+tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+log = { version = "0.4.21", features = [] }
 
 [build-dependencies]
 cmake = "0.1"
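
The hunk above only adds the tracing dependencies; the commit does not show where they get initialized. A minimal sketch of a subscriber setup that would exercise the "json" and "env-filter" features pulled in here (illustrative only, the helper name is an assumption, not part of the commit):

    use tracing_subscriber::EnvFilter;

    // Hypothetical helper: emit JSON-formatted spans and events, filtered
    // through RUST_LOG (e.g. RUST_LOG=info,text_generation_router=debug).
    fn init_tracing() {
        tracing_subscriber::fmt()
            .json()
            .with_env_filter(EnvFilter::from_default_env())
            .init();
    }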

View File

@@ -33,6 +33,8 @@ fn main() {
             "debug" => format!("{}d", dependency),
             _ => String::from(dependency),
         };
+        let dep_path = deps_folder.join(format!("{}-build", dependency));
+        println!("cargo:rustc-link-search={}", dep_path.display());
         println!("cargo:rustc-link-lib=static={}", dep_name);
     }
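
The two added lines follow the standard Cargo build-script protocol: each dependency's "-build" directory is registered as a linker search path before the matching static library is requested. A self-contained sketch of the same pattern (the folder layout and function name are assumptions, not taken from the commit):

    use std::path::Path;

    // Illustrative only: register each dependency's out-of-tree build folder as
    // a linker search path, then ask Cargo to link the static library found there.
    fn emit_link_directives(deps_folder: &Path, dependencies: &[&str]) {
        for dependency in dependencies {
            let dep_path = deps_folder.join(format!("{}-build", dependency));
            println!("cargo:rustc-link-search={}", dep_path.display());
            println!("cargo:rustc-link-lib=static={}", dependency);
        }
    }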

View File

@@ -17,14 +17,11 @@
     set(FAST_BUILD OFF)
 endif ()
 
-# This line turn off DEBUG in TRTLLM logger which is quite spammy
-add_compile_definitions(NDEBUG OFF)
-
 fetchcontent_declare(
         trtllm
-        GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git
-        GIT_TAG a96cccafcf6365c128f004f779160951f8c0801c
-        GIT_SHALLOW TRUE
+        GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
+        GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c
+        GIT_SHALLOW FALSE
 )
 fetchcontent_makeavailable(trtllm)
 
 message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")

View File

@@ -5,7 +5,7 @@
 #ifndef TGI_TRTLLM_BACKEND_FFI_H
 #define TGI_TRTLLM_BACKEND_FFI_H
 
-//#include "rust/cxx.h"
+#include <cstddef>
 #include "backend.h"
 
 namespace huggingface::tgi::backends {
@@ -17,9 +17,9 @@ namespace huggingface::tgi::backends {
 
 namespace huggingface::tgi::backends {
-    struct GenerationContext;
+    // struct GenerationContext;
 
-    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
+    class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
     public:
         /***
          *
@@ -37,7 +37,6 @@
         /***
          *
          * @param tokens
-         * @param maxNewTokens
         * @param topK
         * @param topP
        * @param temperature
@@ -45,17 +44,20 @@
        * @return
        */
        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
-        uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
+        uint64_t
+        Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
 
        /***
         *
        * @param requestId
-         * @param handler
+         * @param ctx
+         * @param callback
        * @return
        */
-        uint32_t Stream(rust::Box<GenerationContext> ctx,
-                        uint64_t requestId,
-                        rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
+        size_t StreamTokens(
+                const RequestId requestId,
+                rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
+                rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback);
     };
 
     /***

View File

@@ -1,160 +1,311 @@
-use std::cell::RefCell;
+use std::future::Future;
 use std::path::Path;
+use std::pin::{pin, Pin};
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::task::{Context, Poll};
+use std::time::Duration;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
+use log::{info, warn};
 use tokenizers::Tokenizer;
-use tokio::sync::mpsc;
-use tokio::time::Instant;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::RwLock;
+use tokio::time::{Instant, sleep};
+use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{instrument, Level, span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
+use text_generation_router::validation::ValidGenerateRequest;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
+// macro_rules! propagate {
+//     ($ctx: expr, $res: expr) => {
+//         $ctx.sender
+//             .send($res)
+//             .expect("Failed to propagate error back to the transport layer")
+//     };
+// }
+
 type InferResult<T> = Result<T, InferError>;
 
-pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+/// Holds the user provided input to be executed along with a channel allowing
+/// to bubble up all the generated tokens for that tokens the to end stream.
+// pub struct InferenceContext {
+//     /// User provided request
+//     request: ValidGenerateRequest,
+//
+//     /// Inter-process communication handler moving token from the executor thread to the HTTP server
+//     sender: UnboundedSender<InferResult<InferStreamResponse>>,
+//
+//     /// Pin the instant this inference context was submitted
+//     when: Instant,
+//
+//     /// Span that will live as long as entry
+//     span: Span,
+// }
 
-pub struct TrtLLmBackend {
-    tokenizer: Tokenizer,
-    inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
+pub(crate) struct Generation {
+    executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
+    done: Arc<AtomicBool>,
 }
 
-unsafe impl Sync for TrtLLmBackend {}
-unsafe impl Send for TrtLLmBackend {}
+pub struct GenerationContext(
+    UnboundedSender<InferResult<InferStreamResponse>>,
+    Arc<AtomicBool>,
+);
-impl TrtLLmBackend {
-    pub fn new<P: AsRef<Path>>(
-        tokenizer: Tokenizer,
-        engine_folder: P,
-    ) -> Result<Self, TensorRtLlmBackendError> {
-        let engine_folder = engine_folder.as_ref();
-        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
-
-        Ok(Self {
-            tokenizer,
-            inner: RefCell::new(inner),
-        })
-    }
-
-    fn infer_text(
-        &self,
-        ctx: GenerationContext,
-        text: &str,
-        params: ValidParameters,
-    ) -> InferResult<()> {
-        // Keep track of processing time
-        let start = Instant::now();
-
-        // Encode the input
-        let ctx = Box::new(ctx);
-        let encoding = self
-            .tokenizer
-            .encode(text, true)
-            .map_err(|e| InferError::ToolError(e.to_string()))?;
-
-        // Submit the request to the backend and retrieve the handle to query its status
-        let request_id = self
-            .inner
-            .borrow_mut()
-            .as_mut()
-            .expect("Failed to retrieve pointer to TRTLLM backend")
-            .submit(
-                encoding.get_ids(),
-                128,
-                params.top_k as i32,
-                params.top_p,
-                params.temperature,
-                params.seed,
-            );
-
-        // Stream generated tokens
-        // spawn_blocking(move || {
-        let num_generated_tokens = self
-            .inner
-            .borrow_mut()
-            .as_mut()
-            .expect("Failed to retrieve pointer to TRTLLM backend")
-            .stream(ctx, request_id, |ctx, token, step, is_final| {
-                // self.tokenizer.decode(&*[token], true).unwrap();
-                let sender = ctx.0;
-                let token = Token {
-                    id: token,
-                    text: String::from(""),
-                    logprob: 1.0f32,
-                    special: false,
-                };
-
-                sender
-                    .send(Ok(InferStreamResponse::Intermediate {
-                        token,
-                        top_tokens: vec![],
-                    }))
-                    .unwrap()
-            });
-
-        // Notify the end
-        let _ = ctx.0.send(Ok(InferStreamResponse::End {
-            token: Token {
-                id: 0,
-                text: String::from(""),
-                logprob: 1.0f32,
-                special: false,
-            },
-            top_tokens: vec![],
-            generated_text: GeneratedText {
-                text: String::from(""),
-                generated_tokens: num_generated_tokens,
-                finish_reason: FinishReason::EndOfSequenceToken,
-                seed: Some(params.seed),
-            },
-            start,
-            queued: Instant::now(),
-        }));
-        // });
-
-        Ok(())
+impl Stream for Generation {
+    type Item = usize;
+
+    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if self.done.load(Ordering::Relaxed) {
+            Poll::Ready(None)
+        } else {
+            let pinned = pin!(self.executor.read());
+            match pinned.poll(ctx) {
+                Poll::Ready(executor_r) => {
+                    let ready = executor_r.num_responses_ready();
+                    if ready == 0 {
+                        let waker = ctx.waker().clone();
+                        tokio::spawn(async {
+                            sleep(Duration::from_millis(10)).await;
+                            waker.wake();
+                        });
+                        Poll::Pending
+                    } else {
+                        info!("Ready: {}", ready);
+                        let waker = ctx.waker().clone();
+                        tokio::spawn(async {
+                            sleep(Duration::from_millis(100)).await;
+                            waker.wake();
+                        });
+                        Poll::Ready(Some(ready))
+                    }
+                }
+                Poll::Pending => {
+                    let waker = ctx.waker().clone();
+                    tokio::spawn(async {
+                        sleep(Duration::from_millis(100)).await;
+                        waker.wake();
+                    });
+                    Poll::Pending
+                }
+            }
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (1, None)
+    }
+}
+
+unsafe impl Send for TensorRtLlmBackendImpl {}
+unsafe impl Sync for TensorRtLlmBackendImpl {}
+
+/// Implements the logic to execute generation with TensorRT-LLM executor API in background
+pub struct TensorRtLlmBackend {
+    // Allowing sending user requests to the TensorRT-LLM executor thread
+    // batcher: UnboundedSender<InferenceContext>,
+    backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
+}
+
+impl TensorRtLlmBackend {
+    pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
+        _tokenizer: Tokenizer,
+        engine_folder: P,
+        _executor_worker_path: Option<PP>,
+    ) -> Result<Self, TensorRtLlmBackendError> {
+        Ok(TensorRtLlmBackend {
+            backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
+                engine_folder.as_ref().to_str().unwrap(),
+                "",
+            ))),
+        })
     }
 }
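
The new Generation stream never relies on the executor to wake the task; whenever it returns Pending it re-arms the waker itself with a short timer so poll_next runs again. A stripped-down, standalone sketch of that polling pattern (the Countdown type and delays are illustrative, not from the commit):

    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use tokio_stream::{Stream, StreamExt};

    // Yields its countdown value, pretending that every other poll finds nothing ready.
    struct Countdown {
        remaining: usize,
    }

    impl Stream for Countdown {
        type Item = usize;

        fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            if self.remaining == 0 {
                return Poll::Ready(None);
            }
            if self.remaining % 2 == 0 {
                // Nothing ready yet: schedule a wake-up ourselves, otherwise the
                // task would never be polled again after returning Pending.
                let waker = ctx.waker().clone();
                tokio::spawn(async move {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    waker.wake();
                });
                self.remaining -= 1;
                Poll::Pending
            } else {
                self.remaining -= 1;
                Poll::Ready(Some(self.remaining))
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let mut ticks = Countdown { remaining: 6 };
        while let Some(n) = ticks.next().await {
            println!("{n} polls left");
        }
    }
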
 #[async_trait]
-impl Backend for TrtLLmBackend {
+impl Backend for TensorRtLlmBackend {
+    #[instrument(skip_all)]
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
     ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
-        let (sender, receiver) = mpsc::unbounded_channel();
-        let ctx = GenerationContext(sender);
-
-        // Unpack parameters
-        let params = request.parameters;
-
-        // Ensure we are running in the right conditions for the input (i.e. single textual chunk)
-        let input = match request.inputs.len() {
-            0 => Err(InferError::GenerationError("No input provided".into())),
-            1 => Ok(request.inputs.first().unwrap()),
-            _ => Err(InferError::GenerationError(format!(
-                "Unsupported multi-chunks ({}) inference.",
-                request.inputs.len()
-            ))),
-        }?;
-
-        // Currently we handle single chunk of text
-        match input {
-            Chunk::Text(text) => {
-                self.infer_text(ctx, &**text, params)?;
-            }
-            Chunk::Image(_) => panic!("Unsupported"),
-        };
+        // Channel to stream the generated token as they come from the worker thread back to the transport layer
+        let (sender, receiver) = unbounded_channel();
+        let executor = self.backend.clone();
+
+        tokio::spawn(async move {
+            // Submit the request to the batcher
+            let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
+                .in_scope(|| async {
+                    info!("Acquiring lock for submit");
+                    let mut handle = executor.write().await;
+                    let request_id = handle.pin_mut().submit(
+                        &vec![2, 2926, 1503, 603, 20189],
+                        50,
+                        1.0,
+                        1.0,
+                        2014,
+                    );
+
+                    info!("Releasing lock for submit");
+                    return request_id;
+                })
+                .await;
+
+            let mut generation = Generation {
+                executor: executor.clone(),
+                done: Arc::new(AtomicBool::new(false)),
+            };
+
+            while let Some(num_tokens_ready) = generation.next().await {
+                span!(
+                    Level::DEBUG,
+                    "[EXECUTOR][GENERATE]",
+                    request_id = request_id,
+                    num_tokens_ready = num_tokens_ready
+                )
+                .in_scope(|| async {
+                    let ctx = Box::new(GenerationContext(
+                        sender.clone(),
+                        Arc::clone(&generation.done),
+                    ));
+                    let mut executor_w = executor.write().await;
+
+                    info!("Acquired write lock stream");
+                    executor_w.pin_mut().stream_tokens(
+                        request_id,
+                        ctx,
+                        |ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
+                            info!("Sending token: {} (final: {})", token, is_final);
+                            let out = if is_final {
+                                ctx.1.store(true, Ordering::Relaxed);
+                                InferStreamResponse::End {
+                                    token: Token {
+                                        id: token,
+                                        text: "".into(),
+                                        logprob,
+                                        special: false,
+                                    },
+                                    top_tokens: vec![],
+                                    generated_text: GeneratedText {
+                                        text: "".into(),
+                                        generated_tokens: u32::MAX,
+                                        finish_reason: FinishReason::EndOfSequenceToken,
+                                        seed: None,
+                                    },
+                                    start: Instant::now(),
+                                    queued: Instant::now(),
+                                }
+                            } else {
+                                InferStreamResponse::Intermediate {
+                                    token: Token {
+                                        id: token,
+                                        text: "".into(),
+                                        logprob,
+                                        special: false,
+                                    },
+                                    top_tokens: vec![],
+                                }
+                            };
+                            ctx.0
+                                .send(Ok(out))
+                                .expect("Failed to send back generated token");
+                        },
+                    );
+                    info!("Releasing write lock stream")
+                })
+                .await;
+            }
+        });
 
         Ok(UnboundedReceiverStream::new(receiver))
     }
 
     async fn health(&self, _current_health: bool) -> bool {
-        self.inner.borrow_mut().is_ready()
+        true
     }
 }
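
On the other side of the channel, whatever called schedule consumes the returned UnboundedReceiverStream until the End message arrives. A hypothetical consumer, shown only to illustrate the shape of the stream (not part of the commit; in TGI this role is played by the router's HTTP layer):

    use tokio_stream::StreamExt;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use text_generation_router::infer::{InferError, InferStreamResponse};

    // Drain the stream returned by Backend::schedule until generation completes.
    async fn drain(mut stream: UnboundedReceiverStream<Result<InferStreamResponse, InferError>>) {
        while let Some(item) = stream.next().await {
            match item {
                Ok(InferStreamResponse::Intermediate { token, .. }) => {
                    println!("token {} ({})", token.id, token.logprob);
                }
                Ok(InferStreamResponse::End { generated_text, .. }) => {
                    println!("done after {} tokens", generated_text.generated_tokens);
                    break;
                }
                Ok(_) => {} // other variants are ignored in this sketch
                Err(err) => {
                    eprintln!("generation failed: {err}");
                    break;
                }
            }
        }
    }
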
+
+// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
+//     engine_folder: P,
+//     _executor_worker: Option<PP>,
+//     tokenizer: Tokenizer,
+//     mut receiver: UnboundedReceiver<InferenceContext>,
+// ) {
+//     let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
+//
+//     while !(receiver.is_closed()) {
+//         // Receive the incoming request
+//         if let Some(ctx) = receiver.recv().await {
+//             debug!("Processing new incoming request");
+//
+//             // We only support single, textual chunk
+//             if ctx.request.inputs.len() != 1 {
+//                 propagate!(
+//                     ctx,
+//                     Err(InferError::GenerationError(format!(
+//                         "Unsupported multi-chunk ({}) input",
+//                         ctx.request.inputs.len()
+//                     )))
+//                 );
+//             }
+//
+//             let input = ctx
+//                 .request
+//                 .inputs
+//                 .first()
+//                 .expect("Single chunk checked above");
+//             let params = ctx.request.parameters;
+//         }
+//     }
+
+// Receive the incoming request
+// if let Some(ctx) = receiver.recv().await {
+//     debug!("Processing new incoming request");
+//     // We only support single, textual chunk
+//     if ctx.request.inputs.len() != 1 {
+//         propagate!(
+//             ctx,
+//             Err(InferError::GenerationError(format!(
+//                 "Unsupported multi-chunk ({}) input",
+//                 ctx.request.inputs.len()
+//             )))
+//         );
+//     }
+//
+//     // Unpack parameters
+//     let inputs = ctx.request.inputs;
+//     let params = ctx.request.parameters;
+//
+//     match inputs.first().unwrap() {
+//         Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
+//             Err(err) => {
+//                 propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
+//             }
+//             Ok(encoding) => {
+//                 // spawn_blocking(|| {
+//                 //     info!("Submitting request to TensorRT-LLM executor");
+//                 //     let mut executor = backend.blocking_write();
+//                 // })
+//                 // .await
+//                 // .expect("");
+//             }
+//         },
+//         Chunk::Image(_) => propagate!(
+//             ctx,
+//             Err(InferError::GenerationError(
+//                 "Image input is not supported yet.".into(),
+//             ))
+//         ),
+//     }
+// };
+// }

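The key change in this file is the ownership model: the raw UniquePtr<TensorRtLlmBackendImpl> produced by cxx is not Send or Sync on its own, so the commit asserts both unsafely and then shares a single handle behind Arc<RwLock<...>> between the spawned request task and the polling Generation stream. A condensed sketch of that sharing pattern (illustrative; the helper names are assumptions, not from the commit):

    use std::sync::Arc;

    use cxx::UniquePtr;
    use tokio::sync::RwLock;

    use crate::ffi::TensorRtLlmBackendImpl;

    type SharedExecutor = Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>;

    // Read-only FFI calls go through the read lock; mutating ones (submit,
    // stream_tokens) take the write lock and pin the C++ object first.
    async fn responses_ready(executor: &SharedExecutor) -> usize {
        executor.read().await.num_responses_ready()
    }

    async fn with_write<R>(
        executor: &SharedExecutor,
        f: impl FnOnce(core::pin::Pin<&mut TensorRtLlmBackendImpl>) -> R,
    ) -> R {
        let mut guard = executor.write().await;
        f(guard.pin_mut())
    }
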
View File

@@ -7,6 +7,7 @@
 #include <filesystem>
 #include <vector>
 
+#include <spdlog/spdlog.h>
 #include "backends/trtllm/include/ffi.h"
@@ -21,42 +22,43 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
 }
 
 uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
-        rust::Slice<const uint32_t> tokens,
-        int32_t maxNewTokens, int32_t topK, float_t topP,
-        float_t temperature, uint64_t seed) {
+        rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) {
 
     // This will copy all the items from the initial slice
     std::vector<int32_t> tokens_(tokens.size());
     tokens_.assign(tokens.begin(), tokens.end());
 
-    return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
+    return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed);
 }
 
-uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
-        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
-        uint64_t requestId,
-        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
-    bool isDone = false;
-    uint32_t numGeneratedTokens = 0;
-
-    do {
-        const auto responses = Poll(requestId);
-        for (const auto &response: responses) {
-            if (response.hasError()) {
-                isDone = true;
-                // TODO : bubble up the error to rust
-            } else {
-                const auto generation = response.getResult();
-                const auto token = generation.outputTokenIds[0][0];
-                isDone = generation.isFinal;
-
-                // Propagate through the handler
-                handler(std::move(ctx), token, numGeneratedTokens, isDone);
-            }
-        }
-    } while (!isDone);
-
-    return numGeneratedTokens;
+size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
-        const uint64_t requestId,
+        const uint64_t requestId,
+        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
+        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback) {
+
+    SPDLOG_INFO("Entering StreamTokens");
+    for (const auto &item: Poll(requestId)) {
+        if (!item.hasError()) {
+            SPDLOG_INFO("\tStreamTokens -> Decoding token...");
+            const auto decoded = item.getResult();
+            SPDLOG_INFO("\tStreamTokens -> Successfully read decoded token ({})", decoded.outputTokenIds[0].size());
+
+            const auto token = decoded.outputTokenIds[0][0];
+            const auto isFinal = decoded.isFinal;
+            // const auto logProb = decoded.logProbs.value()[0][0];
+            const auto logProb = 0.0;
+
+            SPDLOG_INFO(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
+            callback(std::move(ctx), token, logProb, isFinal);
+            SPDLOG_INFO("\tStreamTokens -> Post callback");
+        } else {
+            // TODO : Return rest::Result with error
+            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
+            callback(std::move(ctx), 0, 0.0, true);
+        }
+    }
+
+    SPDLOG_INFO("Exiting StreamTokens");
+    return 0;
 }
 
 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>

View File

@@ -17,7 +17,7 @@ mod ffi {
         /// Represent an instance of the underlying TensorRT-LLM backend
         type TensorRtLlmBackendImpl;
 
-        /// Create an instance backed behind an std::unique_ptr to manage the lifespan of the backend
+        /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
         ///
         /// # Arguments
        ///
@@ -37,29 +37,31 @@ mod ffi {
             executor_worker: &str,
         ) -> UniquePtr<TensorRtLlmBackendImpl>;
 
-        #[rust_name = "is_ready"]
-        fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+        // #[rust_name = "is_ready"]
+        // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+
+        #[rust_name = "num_responses_ready"]
+        fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
 
         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             tokens: &[u32],
-            max_new_tokens: i32,
             top_k: i32,
             top_p: f32,
             temperature: f32,
             seed: u64,
         ) -> u64;
 
-        #[rust_name = "stream"]
-        fn Stream(
+        #[rust_name = "stream_tokens"]
+        fn StreamTokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            ctx: Box<GenerationContext>,
             request_id: u64,
-            callback: fn(Box<GenerationContext>, u32, u32, bool),
-        ) -> u32;
+            ctx: Box<GenerationContext>,
+            cb: fn(Box<GenerationContext>, u32, f32, bool),
+        ) -> usize;
 
-        #[rust_name = "shutdown"]
-        fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
+        // #[rust_name = "shutdown"]
+        // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
     }
 }
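
One detail worth calling out in the bridge above: cb is declared as a plain function pointer (fn(Box<GenerationContext>, u32, f32, bool)), so only a free function or a non-capturing closure can cross the FFI boundary; any per-request state has to travel inside the Box<GenerationContext> argument instead. A self-contained sketch with the same shape (DemoContext is a stand-in for the real GenerationContext, not part of the commit):

    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Stand-in for the real GenerationContext defined in backend.rs; the point is
    // that the callback is a plain fn, so all state rides inside the Box.
    pub struct DemoContext(pub Arc<AtomicBool>);

    fn on_token(ctx: Box<DemoContext>, token_id: u32, logprob: f32, is_final: bool) {
        println!("token={token_id} logprob={logprob} final={is_final}");
        if is_final {
            // Flip the shared "done" flag so the polling side can stop.
            ctx.0.store(true, Ordering::Relaxed);
        }
    }

    // A free function (or non-capturing closure) coerces to the required fn pointer type.
    const _: fn(Box<DemoContext>, u32, f32, bool) = on_token;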

View File

@@ -1,9 +1,11 @@
 use std::collections::HashMap;
+use std::path::PathBuf;
 
 use clap::Parser;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 
-use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
+use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
+use text_generation_backends_trtllm::TensorRtLlmBackend;
 use text_generation_router::server;
 
 /// App Configuration
@@ -53,7 +55,13 @@ struct Args {
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
     #[clap(long, env)]
-    auth_token: Option<String>
+    auth_token: Option<String>,
+    #[clap(
+        long,
+        env,
+        help = "Path to the TensorRT-LLM Orchestrator Worker binary"
+    )]
+    executor_worker: Option<PathBuf>,
 }
 
 #[tokio::main]
@@ -83,7 +91,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         cors_allow_origin,
         messages_api_enabled,
         max_client_batch_size,
-        auth_token
+        auth_token,
+        executor_worker,
     } = args;
 
     // Launch Tokio runtime
@@ -114,6 +123,15 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         }
     }
 
+    if let Some(ref executor_worker) = executor_worker {
+        if !executor_worker.exists() {
+            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
+                "`executor_work` specified path doesn't exists: {}",
+                executor_worker.display()
+            )));
+        }
+    }
+
     // Run server
     let tokenizer = Tokenizer::from_pretrained(
         tokenizer_name.clone(),
@@ -122,9 +140,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
             user_agent: HashMap::new(),
             auth_token,
         }),
-    ).map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
+    )
+    .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
 
-    let backend = TrtLLmBackend::new(tokenizer, model_id)?;
+    let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
     server::run(
         backend,
         max_concurrent_requests,