diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index e6e97c03..edd8caff 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -6,9 +6,9 @@ mod utils;
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
-
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
+    #[derive(Debug, Clone)]
     pub struct GenerationStep {
         request_id: u64,
         token_id: u32,
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index c7225062..7d805863 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -6,19 +6,22 @@ use std::sync::OnceLock;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
+use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::task::JoinHandle;
+use tokio::task::{JoinHandle, spawn_blocking};
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{error, info, Level, span};
 
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::GenerationError;
 use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
 
 // Value used to poll the state of the generation stream
@@ -34,15 +37,21 @@ struct ValidGenerateRequestWithTokens {
     inner: ValidGenerateRequest,
 }
 
+struct DecodedTokenContext {
+    tokens: Vec<GenerationStep>,
+    ctx: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
 fn executor_status_poller(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
+    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
-    loop {
+    'executor: loop {
         span!(Level::DEBUG, "[in-flight][submit]").in_scope(|| {
             // Is there any request pending to be scheduled?
             let awaiting_requests = waiting_requests.len();
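Taken together, the new `DecodedTokenContext` struct and the extra `post_processor_sender` parameter split the old single looper into a two-stage pipeline: the executor loop keeps pulling raw `GenerationStep`s from the C++ executor and ships them over an unbounded mpsc channel to a dedicated post-processing loop (this is also why `GenerationStep` gains `#[derive(Debug, Clone)]` in lib.rs, so steps can be cloned into per-request groups). Below is a minimal, self-contained sketch of that wiring pattern, not the PR's code; `Step`, `producer` and `consumer` are illustrative stand-ins for `GenerationStep`, `executor_status_poller` and `post_processor_looper`.

```rust
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::spawn_blocking;

// Illustrative stand-in for ffi::GenerationStep; Clone is what lets the
// executor loop push copies of each step into per-request groups.
#[derive(Debug, Clone)]
struct Step {
    request_id: u64,
    token_id: u32,
}

// Stage 1: a blocking loop that polls a source and forwards steps downstream.
fn producer(sender: UnboundedSender<Step>) {
    for token_id in 0..3u32 {
        if sender.send(Step { request_id: 1, token_id }).is_err() {
            break; // Receiver dropped: bail out instead of looping forever.
        }
    }
}

// Stage 2: a blocking loop that receives steps and post-processes them.
fn consumer(mut receiver: UnboundedReceiver<Step>) {
    while let Some(step) = receiver.blocking_recv() {
        println!("request {} -> token {}", step.request_id, step.token_id);
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = unbounded_channel();
    // Both stages block, so each gets a dedicated thread off the async runtime,
    // mirroring the two spawn_blocking calls in TensorRtLlmBackendV2::new.
    let producer_handle = spawn_blocking(move || producer(tx));
    let consumer_handle = spawn_blocking(move || consumer(rx));
    let _ = tokio::join!(producer_handle, consumer_handle);
}
```

When `tx` is dropped at the end of the producer, `blocking_recv` returns `None` and the consumer exits cleanly; that is the same shutdown path the post-processor's `is_closed()` check below is guarding against.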
@@ -84,18 +93,40 @@ fn executor_status_poller(
             }
         });
 
-        span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
+        if let Err(e) = span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
             if backend.num_responses_ready() > 0 {
                 match backend.pin_mut().pull_tokens() {
                     Ok(responses) => {
+                        // Worst-case scenario is one token for each response: with_capacity(responses.len())
+                        // grouper groups the decoded tokens per request so multiple tokens can be decoded at once
+                        let mut grouper: HashMap<u64, DecodedTokenContext> =
+                            HashMap::with_capacity(responses.len());
+
+                        // Iterate through all the decoded tokens
                         for step in responses.deref() {
                             let request_id = step.request_id;
+
                             match in_flights.get(&request_id) {
                                 Some(ctx) => {
                                     info!("New token for {} -> {}", request_id, step.token_id);
-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
+                                    if !step.has_error {
+                                        let req_group = grouper.entry_ref(&request_id).or_insert(
+                                            DecodedTokenContext {
+                                                tokens: vec![],
+                                                ctx: ctx.streamer.clone(), // Arc::clone() = cheap
+                                            },
+                                        );
+                                        req_group.tokens.push(step.clone()); // Should be ultra cheap
+
+                                        if step.is_final {
+                                            let _ = in_flights.remove(&step.request_id);
+                                        }
+                                    } else {
+                                        warn!(
+                                            "Error for request: {} -> {}",
+                                            request_id, &step.error_msg
+                                        );
                                     }
                                 }
                                 None => {
@@ -103,19 +134,87 @@ fn executor_status_poller(
                                 }
                             }
                         }
+
+                        grouper
+                            .into_values()
+                            .map(|ctx| post_processor_sender.send(ctx))
+                            .collect()?;
                     }
                     Err(err) => {
                         error!("Failed to retrieve tokens from the executor: {}", err);
                     }
                 }
             }
-        });
+
+            Ok(())
+        }) {
+            error!(
+                "Caught a fatal error in the executor's loop, about to exit. {}",
+                e
+            );
+            break 'executor;
+        }
 
         // Hint the CPU we are spin-locking
         hint::spin_loop();
     }
 }
 
+fn post_processor_looper(
+    tokenizer: Tokenizer,
+    mut decoded_tokens: UnboundedReceiver<DecodedTokenContext>,
+) {
+    'post_processor: loop {
+        if decoded_tokens.is_closed() {
+            warn!("Post processor IPC is closed, loop will exit now.");
+            break 'post_processor;
+        }
+
+        if let Some(ctx) = decoded_tokens.blocking_recv() {
+            ctx.tokens.iter().for_each(|step| {
+                let out = match tokenizer.decode(&[step.token_id], true) {
+                    Ok(text) => {
+                        let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+                        let token = Token {
+                            id: step.token_id,
+                            text,
+                            logprob: step.log_prob,
+                            special: is_special,
+                        };
+
+                        let response = if !step.is_final {
+                            InferStreamResponse::Intermediate {
+                                token,
+                                top_tokens: vec![],
+                            }
+                        } else {
+                            InferStreamResponse::End {
+                                token,
+                                top_tokens: vec![],
+                                generated_text: GeneratedText {
+                                    text: String::from(""),
+                                    generated_tokens: 0,
+                                    finish_reason: FinishReason::Length,
+                                    seed: None,
+                                },
+                                start: Instant::now(),  // Handle start time
+                                queued: Instant::now(), // Handle queued time
+                            }
+                        };
+
+                        Ok(response)
+                    }
+                    Err(e) => Err(GenerationError(e.to_string())),
+                };
+
+                if let Err(e) = ctx.ctx.send(out) {
+                    warn!("Failed to send back the decoded tokens: {}", e);
+                };
+            });
+        }
+    }
+}
+
 struct GenerationContext {
     request: ValidGenerateRequestWithTokens,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
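The post-processor does its per-step work through two `tokenizers` calls: `decode` over a single-id slice, and `get_added_vocabulary().is_special_token(...)` on the decoded text. Here is a standalone sketch of just those calls, assuming the crate's `http` feature for `from_pretrained`; the `gpt2` model id is only an example, and note the PR passes `skip_special_tokens = true` while the sketch uses `false` so the special token actually surfaces in `text`.

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // from_pretrained requires the crate's `http` feature; a local
    // tokenizer.json via Tokenizer::from_file works as well.
    let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;

    // Decode one id at a time, as the post-processor does per step.
    // 15496 is "Hello" in GPT-2; 50256 is the <|endoftext|> special token.
    for id in [15496u32, 50256] {
        let text = tokenizer.decode(&[id], false)?;
        let special = tokenizer.get_added_vocabulary().is_special_token(&text);
        println!("{id} -> {text:?} (special: {special})");
    }
    Ok(())
}
```

Decoding ids one at a time is what makes the per-token `special` flag on `Token` possible, at the cost of losing multi-token merge context.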
@@ -123,8 +222,9 @@ struct GenerationContext {
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
-    looper: JoinHandle<()>,
-    queue: UnboundedSender<GenerationContext>,
+    executor_looper: JoinHandle<()>,
+    post_processor_looper: JoinHandle<()>,
+    executor: UnboundedSender<GenerationContext>,
 }
 
 impl TensorRtLlmBackendV2 {
@@ -150,20 +250,28 @@ impl TensorRtLlmBackendV2 {
         );
 
         // Allocate the IPC layer to communicate with the backend
-        let (requests_sender, requests_receiver) = unbounded_channel::<GenerationContext>();
+        let (executor_sender, executor_receiver) = unbounded_channel();
+        let (post_processor_sender, post_processor_receiver) = unbounded_channel();
 
         // Create the FFI backend
         let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
 
-        // Looper is responsible for scheduling and pulling requests state at regular interval
-        let looper =
-            tokio::task::spawn_blocking(move || executor_status_poller(backend, requests_receiver));
+        // Executor looper is responsible for scheduling and pulling the requests' state at regular intervals
+        let executor_looper = spawn_blocking(move || {
+            executor_status_poller(backend, executor_receiver, post_processor_sender)
+        });
+
+        // Post processor looper is responsible for receiving batches of tokens, decoding them and sending them back to the user
+        let tokenizer_ = tokenizer.clone();
+        let post_processor_looper =
+            spawn_blocking(move || post_processor_looper(tokenizer_, post_processor_receiver));
 
         Ok(TensorRtLlmBackendV2 {
             tokenizer,
-            looper,
-            queue: requests_sender,
+            executor_looper,
+            post_processor_looper,
+            executor: executor_sender,
         })
     }
@@ -212,7 +320,7 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
-        match self.queue.send(GenerationContext { request, streamer }) {
+        match self.executor.send(GenerationContext { request, streamer }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
                 "Failed to submit request to the backend".into(),
@@ -221,6 +329,8 @@
     }
 
     async fn health(&self, current_health: bool) -> bool {
-        current_health & !self.looper.is_finished()
+        current_health
+            & !self.executor_looper.is_finished()
+            & !self.post_processor_looper.is_finished()
     }
 }
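For callers, nothing changes: `schedule` still hands back an `UnboundedReceiverStream` fed by whichever looper owns the sending half. A rough sketch of how such a stream is consumed, with `Result<String, String>` standing in for `InferResult<InferStreamResponse>`:

```rust
use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    let (tx, rx) = unbounded_channel::<Result<String, String>>();
    tx.send(Ok("Hello".into())).unwrap();
    tx.send(Ok(" world".into())).unwrap();
    drop(tx); // Close the channel so the stream terminates.

    // The router consumes the backend's stream the same way; each item is
    // an InferResult<InferStreamResponse> rather than a plain String here.
    let mut stream = UnboundedReceiverStream::new(rx);
    while let Some(item) = stream.next().await {
        match item {
            Ok(chunk) => print!("{chunk}"),
            Err(e) => eprintln!("generation failed: {e}"),
        }
    }
    println!();
}
```

The reworked `health` simply extends the old liveness test: the backend reports healthy only while both `JoinHandle`s are still running, since either loop exiting silently would otherwise strand in-flight requests.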