mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00

impl RwLock scenario for TensorRtLlmBackend

parent 31d9f4d5dc
commit 7784a21d48
@@ -15,6 +15,10 @@ tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot"
 tokio-stream = "0.1.14"
 clap = { version = "4.5.4", features = ["derive"] }
 thiserror = "1.0.61"
+tracing = "0.1"
+tracing-opentelemetry = "0.24"
+tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
+log = { version = "0.4.21", features = [] }

 [build-dependencies]
 cmake = "0.1"
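Note: the tracing, tracing-subscriber and log crates are only declared in this hunk. As a hedged illustration (not part of this commit), a binary crate would typically wire up the two subscriber features selected here ("json" and "env-filter") roughly like this:

// Illustrative only: initialise a JSON-formatted tracing subscriber filtered
// through RUST_LOG, matching the "json" and "env-filter" features above.
use tracing_subscriber::EnvFilter;

fn main() {
    tracing_subscriber::fmt()
        .json() // requires tracing-subscriber's "json" feature
        .with_env_filter(EnvFilter::from_default_env()) // requires "env-filter"
        .init();

    tracing::info!(backend = "trtllm", "logging initialised");
}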
@@ -33,6 +33,8 @@ fn main() {
         "debug" => format!("{}d", dependency),
         _ => String::from(dependency),
     };
+    let dep_path = deps_folder.join(format!("{}-build", dependency));
+    println!("cargo:rustc-link-search={}", dep_path.display());
     println!("cargo:rustc-link-lib=static={}", dep_name);
 }

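The two `println!` lines added above follow the standard Cargo build-script protocol: `cargo:rustc-link-search` adds a directory to the linker search path and `cargo:rustc-link-lib=static=<name>` links a static library found there. A minimal, self-contained sketch of the same pattern (the dependency names and folder layout are made up for illustration):

// Minimal build.rs sketch of the link-search / link-lib pattern (illustrative names).
use std::env;
use std::path::PathBuf;

fn main() {
    // OUT_DIR is provided by Cargo when this runs as a build script.
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by Cargo"));
    let deps_folder = out_dir.join("build"); // hypothetical layout

    for dependency in ["fmt", "spdlog"] {
        // Assume CMake placed each dependency's artifacts in "<name>-build".
        let dep_path = deps_folder.join(format!("{}-build", dependency));
        println!("cargo:rustc-link-search={}", dep_path.display());
        println!("cargo:rustc-link-lib=static={}", dependency);
    }
}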
@@ -17,14 +17,11 @@ else ()
     set(FAST_BUILD OFF)
 endif ()

-# This line turn off DEBUG in TRTLLM logger which is quite spammy
-add_compile_definitions(NDEBUG OFF)
-
 fetchcontent_declare(
         trtllm
-        GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git
-        GIT_TAG a96cccafcf6365c128f004f779160951f8c0801c
-        GIT_SHALLOW TRUE
+        GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
+        GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c
+        GIT_SHALLOW FALSE
 )
 fetchcontent_makeavailable(trtllm)
 message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
@@ -5,7 +5,7 @@
 #ifndef TGI_TRTLLM_BACKEND_FFI_H
 #define TGI_TRTLLM_BACKEND_FFI_H

-//#include "rust/cxx.h"
+#include <cstddef>
 #include "backend.h"

 namespace huggingface::tgi::backends {
@@ -17,9 +17,9 @@ namespace huggingface::tgi::backends {

 namespace huggingface::tgi::backends {

-    struct GenerationContext;
+    // struct GenerationContext;

-    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
+    class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
     public:
         /***
          *
@@ -37,7 +37,6 @@ namespace huggingface::tgi::backends {
         /***
          *
          * @param tokens
-         * @param maxNewTokens
          * @param topK
          * @param topP
         * @param temperature
@@ -45,17 +44,20 @@
          * @return
          */
         [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
-        uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
+        uint64_t
+        Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);

         /***
          *
          * @param requestId
-         * @param handler
+         * @param ctx
+         * @param callback
          * @return
          */
-        uint32_t Stream(rust::Box <GenerationContext> ctx,
-                        uint64_t requestId,
-                        rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
+        size_t StreamTokens(
+                const RequestId requestId,
+                rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
+                rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback);
     };

     /***
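For orientation, the new `StreamTokens` declaration implies a callback contract on the Rust side of the cxx bridge: an opaque boxed `GenerationContext` crosses into C++ and is handed back on every call of a plain function pointer taking `(ctx, token_id, logprob, is_final)`. A hedged Rust sketch of that contract, with an illustrative struct and callback (the real definitions live in backend.rs and the cxx bridge further down):

// Sketch of the Rust-side contract implied by StreamTokens (illustrative types).
pub struct GenerationContext; // opaque state handed back on every callback

fn on_token(ctx: Box<GenerationContext>, token_id: u32, logprob: f32, is_final: bool) {
    // The context is moved in on each call; forward it or let it drop here.
    let _ = ctx;
    println!("token={token_id} logprob={logprob} final={is_final}");
}

fn main() {
    // Stand-in for the C++ executor invoking the callback once per decoded token.
    on_token(Box::new(GenerationContext), 42, 0.0, true);
}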
@@ -1,160 +1,311 @@
-use std::cell::RefCell;
+use std::future::Future;
 use std::path::Path;
+use std::pin::{pin, Pin};
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::task::{Context, Poll};
+use std::time::Duration;

 use async_trait::async_trait;
 use cxx::UniquePtr;
+use log::{info, warn};
 use tokenizers::Tokenizer;
-use tokio::sync::mpsc;
-use tokio::time::Instant;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::RwLock;
+use tokio::time::{Instant, sleep};
+use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{instrument, Level, span};

 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
+use text_generation_router::validation::ValidGenerateRequest;

 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};

+// macro_rules! propagate {
+//     ($ctx: expr, $res: expr) => {
+//         $ctx.sender
+//             .send($res)
+//             .expect("Failed to propagate error back to the transport layer")
+//     };
+// }
+
 type InferResult<T> = Result<T, InferError>;

-pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+/// Holds the user provided input to be executed along with a channel allowing
+/// to bubble up all the generated tokens for that tokens the to end stream.
+// pub struct InferenceContext {
+//     /// User provided request
+//     request: ValidGenerateRequest,
+//
+//     /// Inter-process communication handler moving token from the executor thread to the HTTP server
+//     sender: UnboundedSender<InferResult<InferStreamResponse>>,
+//
+//     /// Pin the instant this inference context was submitted
+//     when: Instant,
+//
+//     /// Span that will live as long as entry
+//     span: Span,
+// }

-pub struct TrtLLmBackend {
-    tokenizer: Tokenizer,
-    inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
+pub(crate) struct Generation {
+    executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
+    done: Arc<AtomicBool>,
 }

-unsafe impl Sync for TrtLLmBackend {}
-unsafe impl Send for TrtLLmBackend {}
-
-impl TrtLLmBackend {
-    pub fn new<P: AsRef<Path>>(
-        tokenizer: Tokenizer,
-        engine_folder: P,
-    ) -> Result<Self, TensorRtLlmBackendError> {
-        let engine_folder = engine_folder.as_ref();
-        let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
-
-        Ok(Self {
-            tokenizer,
-            inner: RefCell::new(inner),
-        })
-    }
-
-    fn infer_text(
-        &self,
-        ctx: GenerationContext,
-        text: &str,
-        params: ValidParameters,
-    ) -> InferResult<()> {
-        // Keep track of processing time
-        let start = Instant::now();
-
-        // Encode the input
-        let ctx = Box::new(ctx);
-        let encoding = self
-            .tokenizer
-            .encode(text, true)
-            .map_err(|e| InferError::ToolError(e.to_string()))?;
-
-        // Submit the request to the backend and retrieve the handle to query its status
-        let request_id = self
-            .inner
-            .borrow_mut()
-            .as_mut()
-            .expect("Failed to retrieve pointer to TRTLLM backend")
-            .submit(
-                encoding.get_ids(),
-                128,
-                params.top_k as i32,
-                params.top_p,
-                params.temperature,
-                params.seed,
-            );
-
-        // Stream generated tokens
-        // spawn_blocking(move || {
-        let num_generated_tokens = self
-            .inner
-            .borrow_mut()
-            .as_mut()
-            .expect("Failed to retrieve pointer to TRTLLM backend")
-            .stream(ctx, request_id, |ctx, token, step, is_final| {
-                // self.tokenizer.decode(&*[token], true).unwrap();
-                let sender = ctx.0;
-                let token = Token {
-                    id: token,
-                    text: String::from(""),
-                    logprob: 1.0f32,
-                    special: false,
-                };
-
-                sender
-                    .send(Ok(InferStreamResponse::Intermediate {
-                        token,
-                        top_tokens: vec![],
-                    }))
-                    .unwrap()
-            });
-
-        // Notify the end
-        let _ = ctx.0.send(Ok(InferStreamResponse::End {
-            token: Token {
-                id: 0,
-                text: String::from(""),
-                logprob: 1.0f32,
-                special: false,
-            },
-            top_tokens: vec![],
-            generated_text: GeneratedText {
-                text: String::from(""),
-                generated_tokens: num_generated_tokens,
-                finish_reason: FinishReason::EndOfSequenceToken,
-                seed: Some(params.seed),
-            },
-            start,
-            queued: Instant::now(),
-        }));
-        // });
-
-        Ok(())
+pub struct GenerationContext(
+    UnboundedSender<InferResult<InferStreamResponse>>,
+    Arc<AtomicBool>,
+);
+
+impl Stream for Generation {
+    type Item = usize;
+
+    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if self.done.load(Ordering::Relaxed) {
+            Poll::Ready(None)
+        } else {
+            let pinned = pin!(self.executor.read());
+            match pinned.poll(ctx) {
+                Poll::Ready(executor_r) => {
+                    let ready = executor_r.num_responses_ready();
+                    if ready == 0 {
+                        let waker = ctx.waker().clone();
+                        tokio::spawn(async {
+                            sleep(Duration::from_millis(10)).await;
+                            waker.wake();
+                        });
+                        Poll::Pending
+                    } else {
+                        info!("Ready: {}", ready);
+                        let waker = ctx.waker().clone();
+                        tokio::spawn(async {
+                            sleep(Duration::from_millis(100)).await;
+                            waker.wake();
+                        });
+                        Poll::Ready(Some(ready))
+                    }
+                }
+                Poll::Pending => {
+                    let waker = ctx.waker().clone();
+                    tokio::spawn(async {
+                        sleep(Duration::from_millis(100)).await;
+                        waker.wake();
+                    });
+                    Poll::Pending
+                }
+            }
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (1, None)
+    }
+}
+
+unsafe impl Send for TensorRtLlmBackendImpl {}
+unsafe impl Sync for TensorRtLlmBackendImpl {}
+
+/// Implements the logic to execute generation with TensorRT-LLM executor API in background
+pub struct TensorRtLlmBackend {
+    // Allowing sending user requests to the TensorRT-LLM executor thread
+    // batcher: UnboundedSender<InferenceContext>,
+    backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
+}
+
+impl TensorRtLlmBackend {
+    pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
+        _tokenizer: Tokenizer,
+        engine_folder: P,
+        _executor_worker_path: Option<PP>,
+    ) -> Result<Self, TensorRtLlmBackendError> {
+        Ok(TensorRtLlmBackend {
+            backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
+                engine_folder.as_ref().to_str().unwrap(),
+                "",
+            ))),
+        })
     }
 }

 #[async_trait]
-impl Backend for TrtLLmBackend {
+impl Backend for TensorRtLlmBackend {
+    #[instrument(skip_all)]
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        _request: ValidGenerateRequest,
     ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
-        let (sender, receiver) = mpsc::unbounded_channel();
-        let ctx = GenerationContext(sender);
-
-        // Unpack parameters
-        let params = request.parameters;
-
-        // Ensure we are running in the right conditions for the input (i.e. single textual chunk)
-        let input = match request.inputs.len() {
-            0 => Err(InferError::GenerationError("No input provided".into())),
-            1 => Ok(request.inputs.first().unwrap()),
-            _ => Err(InferError::GenerationError(format!(
-                "Unsupported multi-chunks ({}) inference.",
-                request.inputs.len()
-            ))),
-        }?;
-
-        // Currently we handle single chunk of text
-        match input {
-            Chunk::Text(text) => {
-                self.infer_text(ctx, &**text, params)?;
-            }
-            Chunk::Image(_) => panic!("Unsupported"),
-        };
+        // Channel to stream the generated token as they come from the worker thread back to the transport layer
+        let (sender, receiver) = unbounded_channel();
+
+        let executor = self.backend.clone();
+        tokio::spawn(async move {
+            // Submit the request to the batcher
+            let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
+                .in_scope(|| async {
+                    info!("Acquiring lock for submit");
+                    let mut handle = executor.write().await;
+                    let request_id = handle.pin_mut().submit(
+                        &vec![2, 2926, 1503, 603, 20189],
+                        50,
+                        1.0,
+                        1.0,
+                        2014,
+                    );
+
+                    info!("Releasing lock for submit");
+                    return request_id;
+                })
+                .await;
+
+            let mut generation = Generation {
+                executor: executor.clone(),
+                done: Arc::new(AtomicBool::new(false)),
+            };
+
+            while let Some(num_tokens_ready) = generation.next().await {
+                span!(
+                    Level::DEBUG,
+                    "[EXECUTOR][GENERATE]",
+                    request_id = request_id,
+                    num_tokens_ready = num_tokens_ready
+                )
+                .in_scope(|| async {
+                    let ctx = Box::new(GenerationContext(
+                        sender.clone(),
+                        Arc::clone(&generation.done),
+                    ));
+                    let mut executor_w = executor.write().await;
+
+                    info!("Acquired write lock stream");
+                    executor_w.pin_mut().stream_tokens(
+                        request_id,
+                        ctx,
+                        |ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
+                            info!("Sending token: {} (final: {})", token, is_final);
+                            let out = if is_final {
+                                ctx.1.store(true, Ordering::Relaxed);
+                                InferStreamResponse::End {
+                                    token: Token {
+                                        id: token,
+                                        text: "".into(),
+                                        logprob,
+                                        special: false,
+                                    },
+                                    top_tokens: vec![],
+                                    generated_text: GeneratedText {
+                                        text: "".into(),
+                                        generated_tokens: u32::MAX,
+                                        finish_reason: FinishReason::EndOfSequenceToken,
+                                        seed: None,
+                                    },
+                                    start: Instant::now(),
+                                    queued: Instant::now(),
+                                }
+                            } else {
+                                InferStreamResponse::Intermediate {
+                                    token: Token {
+                                        id: token,
+                                        text: "".into(),
+                                        logprob,
+                                        special: false,
+                                    },
+                                    top_tokens: vec![],
+                                }
+                            };
+                            ctx.0
+                                .send(Ok(out))
+                                .expect("Failed to send back generated token");
+                        },
+                    );
+                    info!("Releasing write lock stream")
+                })
+                .await;
+            }
+        });

         Ok(UnboundedReceiverStream::new(receiver))
     }

     async fn health(&self, _current_health: bool) -> bool {
-        self.inner.borrow_mut().is_ready()
+        true
     }
 }
+
+// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
+//     engine_folder: P,
+//     _executor_worker: Option<PP>,
+//     tokenizer: Tokenizer,
+//     mut receiver: UnboundedReceiver<InferenceContext>,
+// ) {
+//     let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
+//
+//     while !(receiver.is_closed()) {
+//         // Receive the incoming request
+//         if let Some(ctx) = receiver.recv().await {
+//             debug!("Processing new incoming request");
+//
+//             // We only support single, textual chunk
+//             if ctx.request.inputs.len() != 1 {
+//                 propagate!(
+//                     ctx,
+//                     Err(InferError::GenerationError(format!(
+//                         "Unsupported multi-chunk ({}) input",
+//                         ctx.request.inputs.len()
+//                     )))
+//                 );
+//             }
+//
+//             let input = ctx
+//                 .request
+//                 .inputs
+//                 .first()
+//                 .expect("Single chunk checked above");
+//             let params = ctx.request.parameters;
+//         }
+//     }
+
+// Receive the incoming request
+// if let Some(ctx) = receiver.recv().await {
+//     debug!("Processing new incoming request");
+
+//     // We only support single, textual chunk
+//     if ctx.request.inputs.len() != 1 {
+//         propagate!(
+//             ctx,
+//             Err(InferError::GenerationError(format!(
+//                 "Unsupported multi-chunk ({}) input",
+//                 ctx.request.inputs.len()
+//             )))
+//         );
+//     }
+//
+//     // Unpack parameters
+//     let inputs = ctx.request.inputs;
+//     let params = ctx.request.parameters;
+//
+//     match inputs.first().unwrap() {
+//         Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
+//             Err(err) => {
+//                 propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
+//             }
+//             Ok(encoding) => {
+//                 // spawn_blocking(|| {
+//                 //     info!("Submitting request to TensorRT-LLM executor");
+//                 //     let mut executor = backend.blocking_write();
+//                 // })
+//                 // .await
+//                 // .expect("");
+//             }
+//         },
+//         Chunk::Image(_) => propagate!(
+//             ctx,
+//             Err(InferError::GenerationError(
+//                 "Image input is not supported yet.".into(),
+//             ))
+//         ),
+//     }
+// };
+// }
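The heart of this commit is visible in the hunk above: the FFI backend now lives behind an Arc<tokio::sync::RwLock<UniquePtr<...>>> shared between the request task and a hand-rolled `Generation` stream, whose `poll_next` acquires the read lock, checks `num_responses_ready()`, and re-arms itself by spawning a short `sleep` that wakes the task. A reduced, self-contained sketch of that polling pattern, with a plain counter standing in for the FFI executor (everything here is illustrative and assumes tokio with the "sync", "time", "rt-multi-thread" and "macros" features plus tokio-stream):

// Reduced sketch of the Generation stream above: poll a shared resource behind
// a tokio RwLock, and if nothing is ready yet, re-wake ourselves after a delay.
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::sync::RwLock;
use tokio::time::sleep;
use tokio_stream::{Stream, StreamExt};

struct Generation {
    // Stand-in for Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>.
    executor: Arc<RwLock<usize>>,
    done: Arc<AtomicBool>,
}

impl Stream for Generation {
    type Item = usize;

    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.done.load(Ordering::Relaxed) {
            return Poll::Ready(None);
        }
        // Poll the async read-lock acquisition without an async fn body.
        let lock = pin!(self.executor.read());
        match lock.poll(ctx) {
            Poll::Ready(ready_count) => {
                if *ready_count == 0 {
                    // Nothing ready: schedule a wake-up instead of busy-looping.
                    let waker = ctx.waker().clone();
                    tokio::spawn(async move {
                        sleep(Duration::from_millis(10)).await;
                        waker.wake();
                    });
                    Poll::Pending
                } else {
                    Poll::Ready(Some(*ready_count))
                }
            }
            // The RwLock future already registered our waker for us.
            Poll::Pending => Poll::Pending,
        }
    }
}

#[tokio::main]
async fn main() {
    let executor = Arc::new(RwLock::new(0usize));
    let done = Arc::new(AtomicBool::new(false));
    let mut generation = Generation { executor: executor.clone(), done: done.clone() };

    // Simulate the executor producing work, then completing.
    tokio::spawn(async move {
        sleep(Duration::from_millis(30)).await;
        *executor.write().await = 3;
        sleep(Duration::from_millis(30)).await;
        done.store(true, Ordering::Relaxed);
    });

    while let Some(ready) = generation.next().await {
        println!("responses ready: {ready}");
        sleep(Duration::from_millis(20)).await;
    }
}

The 10 ms and 100 ms sleeps in the diff serve the same purpose as the single sleep here: a stream that returns Pending without registering a waker anywhere would never be polled again.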
@@ -7,6 +7,7 @@
 #include <filesystem>
 #include <vector>

+#include <spdlog/spdlog.h>
 #include "backends/trtllm/include/ffi.h"


@@ -21,42 +22,43 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
 }

 uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
-        rust::Slice<const uint32_t> tokens,
-        int32_t maxNewTokens, int32_t topK, float_t topP,
-        float_t temperature, uint64_t seed) {
+        rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) {

     // This will copy all the items from the initial slice
     std::vector<int32_t> tokens_(tokens.size());
     tokens_.assign(tokens.begin(), tokens.end());

-    return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
+    return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed);
 }

-uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
-        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
-        uint64_t requestId,
-        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
-    bool isDone = false;
-    uint32_t numGeneratedTokens = 0;
-
-    do {
-        const auto responses = Poll(requestId);
-        for (const auto &response: responses) {
-            if (response.hasError()) {
-                isDone = true;
-                // TODO : bubble up the error to rust
-            } else {
-                const auto generation = response.getResult();
-                const auto token = generation.outputTokenIds[0][0];
-                isDone = generation.isFinal;
-
-                // Propagate through the handler
-                handler(std::move(ctx), token, numGeneratedTokens, isDone);
-            }
-        }
-    } while (!isDone);
-
-    return numGeneratedTokens;
+size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(const uint64_t requestId,
+        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
+        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback) {
+
+    SPDLOG_INFO("Entering StreamTokens");
+    for (const auto &item: Poll(requestId)) {
+        if (!item.hasError()) {
+            SPDLOG_INFO("\tStreamTokens -> Decoding token...");
+            const auto decoded = item.getResult();
+            SPDLOG_INFO("\tStreamTokens -> Successfully read decoded token ({})", decoded.outputTokenIds[0].size());
+
+            const auto token = decoded.outputTokenIds[0][0];
+            const auto isFinal = decoded.isFinal;
+            // const auto logProb = decoded.logProbs.value()[0][0];
+            const auto logProb = 0.0;
+
+            SPDLOG_INFO(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
+            callback(std::move(ctx), token, logProb, isFinal);
+            SPDLOG_INFO("\tStreamTokens -> Post callback");
+        } else {
+            // TODO : Return rest::Result with error
+            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
+            callback(std::move(ctx), 0, 0.0, true);
+        }
+    }
+
+    SPDLOG_INFO("Exiting StreamTokens");
+    return 0;
 }

 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
@@ -17,7 +17,7 @@ mod ffi {
         /// Represent an instance of the underlying TensorRT-LLM backend
         type TensorRtLlmBackendImpl;

-        /// Create an instance backed behind an std::unique_ptr to manage the lifespan of the backend
+        /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
         ///
         /// # Arguments
         ///
@@ -37,29 +37,31 @@ mod ffi {
             executor_worker: &str,
         ) -> UniquePtr<TensorRtLlmBackendImpl>;

-        #[rust_name = "is_ready"]
-        fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+        // #[rust_name = "is_ready"]
+        // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
+
+        #[rust_name = "num_responses_ready"]
+        fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;

         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             tokens: &[u32],
-            max_new_tokens: i32,
             top_k: i32,
             top_p: f32,
             temperature: f32,
             seed: u64,
         ) -> u64;

-        #[rust_name = "stream"]
-        fn Stream(
+        #[rust_name = "stream_tokens"]
+        fn StreamTokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            ctx: Box<GenerationContext>,
             request_id: u64,
-            callback: fn(Box<GenerationContext>, u32, u32, bool),
-        ) -> u32;
+            ctx: Box<GenerationContext>,
+            cb: fn(Box<GenerationContext>, u32, f32, bool),
+        ) -> usize;

-        #[rust_name = "shutdown"]
-        fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
+        // #[rust_name = "shutdown"]
+        // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
     }
 }
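Two details of the bridge above are easy to miss: methods declared with `self: Pin<&mut TensorRtLlmBackendImpl>` are reached through `UniquePtr::pin_mut()` on the Rust side, and the `cb: fn(...)` parameter is a plain function pointer, so only non-capturing closures or free functions can cross the boundary. A hedged, self-contained sketch of that calling pattern, with `FakeBackend` standing in for the generated `ffi::TensorRtLlmBackendImpl`:

// Illustrative stand-in for the cxx-generated type; shows the pin_mut-style
// calling convention and the fn-pointer callback the bridge above requires.
use std::pin::Pin;

struct GenerationContext(u32);
struct FakeBackend;

impl FakeBackend {
    fn submit(self: Pin<&mut Self>, tokens: &[u32], _top_k: i32, _top_p: f32, _temp: f32, _seed: u64) -> u64 {
        tokens.len() as u64 // pretend request id
    }
    fn stream_tokens(
        self: Pin<&mut Self>,
        _request_id: u64,
        ctx: Box<GenerationContext>,
        cb: fn(Box<GenerationContext>, u32, f32, bool),
    ) -> usize {
        cb(ctx, 7, 0.0, true); // emit one fake token
        1
    }
}

fn main() {
    let mut backend = Box::pin(FakeBackend);
    let request_id = backend.as_mut().submit(&[1, 2, 3], 50, 1.0, 1.0, 2014);
    backend.as_mut().stream_tokens(
        request_id,
        Box::new(GenerationContext(0)),
        // A non-capturing closure coerces to `fn(...)`, matching the bridge signature.
        |ctx, token, logprob, is_final| {
            println!("ctx={} token={token} logprob={logprob} final={is_final}", ctx.0);
        },
    );
}

This constraint is also why the diff threads `sender.clone()` and the `done` flag through the boxed `GenerationContext` instead of capturing them in the callback.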
@@ -1,9 +1,11 @@
 use std::collections::HashMap;
+use std::path::PathBuf;

 use clap::Parser;
 use tokenizers::{FromPretrainedParameters, Tokenizer};

-use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
+use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
+use text_generation_backends_trtllm::TensorRtLlmBackend;
 use text_generation_router::server;

 /// App Configuration
@@ -53,7 +55,13 @@ struct Args {
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
     #[clap(long, env)]
-    auth_token: Option<String>
+    auth_token: Option<String>,
+    #[clap(
+        long,
+        env,
+        help = "Path to the TensorRT-LLM Orchestrator Worker binary"
+    )]
+    executor_worker: Option<PathBuf>,
 }

 #[tokio::main]
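The new `executor_worker` flag is a plain optional path argument. A minimal stand-alone sketch of the same clap pattern (field names mirror the diff; the check in `main` is the illustrative counterpart of the validation added further down, and the `env` attribute assumes clap's "env" feature is enabled alongside "derive"):

// Minimal clap sketch of the optional-path argument added above.
use std::path::PathBuf;

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    #[clap(long, env)]
    auth_token: Option<String>,
    #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator Worker binary")]
    executor_worker: Option<PathBuf>,
}

fn main() {
    let args = Args::parse();
    if let Some(ref executor_worker) = args.executor_worker {
        // Mirrors the existence check performed later in main.rs.
        if !executor_worker.exists() {
            eprintln!("executor_worker path doesn't exist: {}", executor_worker.display());
            std::process::exit(1);
        }
    }
    println!("{args:?}");
}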
@@ -83,7 +91,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         cors_allow_origin,
         messages_api_enabled,
         max_client_batch_size,
-        auth_token
+        auth_token,
+        executor_worker,
     } = args;

     // Launch Tokio runtime
@@ -114,6 +123,15 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         }
     }

+    if let Some(ref executor_worker) = executor_worker {
+        if !executor_worker.exists() {
+            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
+                "`executor_work` specified path doesn't exists: {}",
+                executor_worker.display()
+            )));
+        }
+    }
+
     // Run server
     let tokenizer = Tokenizer::from_pretrained(
         tokenizer_name.clone(),
@@ -122,9 +140,10 @@
             user_agent: HashMap::new(),
             auth_token,
         }),
-    ).map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
+    )
+    .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;

-    let backend = TrtLLmBackend::new(tokenizer, model_id)?;
+    let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
     server::run(
         backend,
         max_concurrent_requests,