diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 272a04c0..63acdfa4 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -1,10 +1,9 @@
-use std::hint;
-use std::ops::Deref;
-use std::path::Path;
-
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
@@ -30,9 +29,10 @@ type InferResult<T> = Result<T, InferError>;
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
     request: ValidGenerateRequest,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
 #[derive(Debug, Copy, Clone)]
@@ -58,31 +58,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     }
 }
 
-/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
-struct DecodedTokenContext {
-    token: DecodedToken,
-    start: Option<Instant>,
-    queued: Instant,
-    channel: UnboundedSender<InferResult<InferStreamResponse>>,
-}
-
 fn executor_status_looper(
-    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     max_inflight_requests: usize,
-    mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
+    tokenizer: Tokenizer,
+    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+    mut backlog: UnboundedReceiver<GenerationContext>,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights =
         HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
 
-    // TODO: Does it need a spin-loop?
     'scheduler: loop {
         // Is there any request pending to be scheduled?
-        let awaiting_requests = waiting_requests.len();
+        let awaiting_requests = backlog.len();
         for _ in 0..awaiting_requests {
             // Retrieve all the requests
-            if let Some(mut ctx) = waiting_requests.blocking_recv() {
+            if let Some(ctx) = backlog.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
                 let generation_params = &request.parameters;
@@ -103,7 +94,6 @@ fn executor_status_looper(
                     Ok(request_id) => {
                         // Insert the context linked to the generated request id in the tracker
                         debug!("[in-flight] Added {}", request_id);
-                        ctx.start = Some(Instant::now());
                         in_flights.insert(request_id, ctx);
                     }
                     Err(e) => {
@@ -117,6 +107,8 @@
                         }
                     }
                 };
+            } else {
+                break 'scheduler;
             }
         }
 
@@ -125,21 +117,28 @@
                 Ok(responses) => {
                     // Iterate through all the decoded token
                     for step in responses.deref() {
-                        if let Some(ctx) = in_flights.get(&step.request_id) {
-                            // Remove from tracked requests
-                            let parcel =
-                                DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                                    token: dt,
-                                    start: ctx.start,
-                                    queued: ctx.queued,
-                                    channel: ctx.streamer.clone(),
-                                });
+                        if let Some(ctx) = in_flights.get_mut(&step.request_id) {
+                            // Update the starting timestamp if not set
+                            // This value might not be the actual real starting time of the request
+                            // on the executor side - Need to expose more info from the executor to
+                            // retrieve this value
+                            // TODO : Expose actual real starting time for a request on FFI layer
+                            if ctx.start.is_none() {
+                                ctx.start = Some(Instant::now());
+                            }
 
-                            // Submit the work to p:the post_processor
-                            let posted = post_processor_sender.send((step.request_id, parcel));
+                            // Try to map the generation step to a DecodedToken
+                            let response = match DecodedToken::try_from(step) {
+                                Ok(decoded_token) => {
+                                    post_process_decoded_token(&tokenizer, ctx, decoded_token)
+                                }
+                                Err(err) => Err(err)
+                            };
 
-                            if posted.is_err() || step.is_final {
-                                debug!("Removing {}", step.request_id);
+                            // Attempt to send back the response to the client
+                            if let Err(_) = ctx.streamer.send(response) {
+                                // Client has dropped, remove from tracked requests
+                                debug!("Client dropped - removing request {} from tracked requests", step.request_id);
                                 backend.pin_mut().cancel(step.request_id);
                                 let _ = in_flights.remove(&step.request_id);
                             }
@@ -160,80 +159,48 @@
     }
 }
 
-fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
-    tokenizer: Tokenizer,
-    max_inflight_requests: usize,
-    mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
-) {
-    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
-
-    'post_processor: loop {
-        if decoded_tokens.is_closed() {
-            warn!("Post processor IPC is closed, loop will exit now.");
-            break 'post_processor;
-        }
-
-        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
-            match decoded {
-                Ok(ctx) => {
-                    states
-                        .entry(request_id)
-                        .and_modify(|s| s.push(*&ctx.token.id))
-                        .or_insert_with(|| {
-                            let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
-                            state.push(*&ctx.token.id);
-                            state
-                        });
-
-                    let out = match tokenizer.decode(&[ctx.token.id], false) {
-                        Ok(text) => {
-                            let is_special =
-                                tokenizer.get_added_vocabulary().is_special_token(&text);
-                            let token = Token {
-                                id: ctx.token.id,
-                                text,
-                                logprob: ctx.token.log_prob,
-                                special: is_special,
-                            };
-
-                            let out = if !ctx.token.is_final {
-                                InferStreamResponse::Intermediate {
-                                    token,
-                                    top_tokens: vec![],
-                                }
-                            } else {
-                                let tokens = states.remove(&request_id).unwrap();
-                                let text = tokenizer.decode(&tokens, true);
-                                let generated_text = GeneratedText {
-                                    text: text.unwrap(),
-                                    generated_tokens: tokens.len() as u32,
-                                    finish_reason: FinishReason::EndOfSequenceToken,
-                                    seed: None,
-                                };
-
-                                InferStreamResponse::End {
-                                    token,
-                                    top_tokens: vec![],
-                                    generated_text,
-                                    start: ctx.start.unwrap(),
-                                    queued: ctx.queued,
-                                }
-                            };
-
-                            Ok(out)
-                        }
-                        Err(err) => Err(GenerationError(err.to_string())),
-                    };
-
-                    if let Err(_) = ctx.channel.send(out) {
-                        warn!("Failed to send decoded token back to the user")
-                    }
-                }
-                Err(_err) => {
-                    todo!("what do we do?")
-                }
-            }
-        }
-    }
-}
+fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
+    match tokenizer.decode(&[decoded_token.id], false) {
+        Ok(text) => {
+            let is_special =
+                tokenizer.get_added_vocabulary().is_special_token(&text);
+            let token = Token {
+                id: decoded_token.id,
+                text,
+                logprob: decoded_token.log_prob,
+                special: is_special,
+            };
+
+            // Append the token to the tracked generated tokens
+            ctx.tokens.push(token.id);
+
+            // Map the correct response depending on the step is final or not
+            let out = if !decoded_token.is_final {
+                InferStreamResponse::Intermediate {
+                    token,
+                    top_tokens: vec![],
+                }
+            } else {
+                let text = tokenizer.decode(&ctx.tokens, true);
+                let generated_text = GeneratedText {
+                    text: text.unwrap(),
+                    generated_tokens: ctx.tokens.len() as u32,
+                    finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+                    seed: None,
+                };
+
+                InferStreamResponse::End {
+                    token,
+                    top_tokens: vec![],
+                    generated_text,
+                    start: ctx.start.unwrap(),
+                    queued: ctx.queued,
+                }
+            };
+
+            Ok(out)
+        }
+        Err(err) => Err(GenerationError(err.to_string())),
+    }
+}
 
@@ -278,11 +245,8 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
 
 unsafe impl Send for TensorRtLlmBackendImpl {}
 
-pub struct TensorRtLlmBackendV2 {
-    executor_looper: JoinHandle<()>,
-    post_processor_looper: JoinHandle<()>,
-    executor: UnboundedSender<GenerationContext>,
-}
+pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);
+
 
 impl TensorRtLlmBackendV2 {
     pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
@@ -296,32 +260,22 @@ impl TensorRtLlmBackendV2 {
 
         // Allocate the IPC layer to communicate with the backend
         let (executor_sender, executor_receiver) = unbounded_channel();
-        let (post_processor_sender, post_processor_receiver) = unbounded_channel();
 
         // Create the FFI backend
         let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
 
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
-        let executor_looper = spawn_blocking(move || {
+        spawn_blocking(move || {
             executor_status_looper(
-                backend,
                 max_inflight_requests,
+                tokenizer,
+                backend,
                 executor_receiver,
-                post_processor_sender,
             )
         });
 
-        // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
-        let post_processor_looper = spawn_blocking(move || {
-            post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
-        });
-
-        Ok(TensorRtLlmBackendV2 {
-            executor_looper,
-            post_processor_looper,
-            executor: executor_sender,
-        })
+        Ok(TensorRtLlmBackendV2(executor_sender))
     }
 
     fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
@@ -355,20 +309,21 @@
 impl Backend for TensorRtLlmBackendV2 {
     fn schedule(
         &self,
-        inner: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        Self::validate(&inner)?;
+        Self::validate(&request)?;
 
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
         let queued = Instant::now();
-        match self.executor.send(GenerationContext {
-            request: inner,
+        match self.0.send(GenerationContext {
+            request,
+            streamer,
+            tokens: Vec::with_capacity(256),
             start: None,
             queued,
-            streamer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
@@ -378,6 +333,6 @@
     }
 
     async fn health(&self, _: bool) -> bool {
-        !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
+        true
     }
 }
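Note on the refactor: the diff collapses the two blocking threads (the executor looper and the post-processor looper) into a single `executor_status_looper` that decodes each `GenerationStep` inline via `post_process_decoded_token` and streams it straight back to the client; per-request token state moves from the post-processor's `HashMap` into `GenerationContext::tokens`. Below is a minimal, self-contained sketch of that single-looper pattern. It assumes nothing from the real crate: `FakeStep`, `Ctx` and `run_looper` are hypothetical stand-ins, `std::sync::mpsc` stands in for tokio's unbounded channels, and a plain `Vec` of steps stands in for the TensorRT-LLM FFI executor.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Hypothetical stand-in for a TensorRT-LLM GenerationStep
struct FakeStep {
    request_id: u64,
    token_id: u32,
    is_final: bool,
}

// Hypothetical stand-in for GenerationContext
struct Ctx {
    tokens: Vec<u32>,         // accumulated ids, like GenerationContext::tokens
    streamer: Sender<String>, // per-request stream back to the client
}

fn run_looper(backlog: Receiver<(u64, Ctx)>, steps: Vec<FakeStep>) {
    let mut in_flights: HashMap<u64, Ctx> = HashMap::new();

    // Scheduler half: drain the backlog into the in-flight tracker
    while let Ok((request_id, ctx)) = backlog.try_recv() {
        in_flights.insert(request_id, ctx);
    }

    // Decode half: post-process each step and stream it back inline
    for step in steps {
        if let Some(ctx) = in_flights.get_mut(&step.request_id) {
            ctx.tokens.push(step.token_id);
            let text = format!("token-{}", step.token_id); // stand-in for tokenizer.decode
            // A failed send means the client dropped its receiver; the diff
            // cancels the request and evicts it in exactly that case. We also
            // evict finished requests here to keep the sketch short.
            if ctx.streamer.send(text).is_err() || step.is_final {
                if let Some(done) = in_flights.remove(&step.request_id) {
                    // Mirrors generated_tokens: ctx.tokens.len() in the End response
                    eprintln!("request {} produced {} tokens", step.request_id, done.tokens.len());
                }
            }
        }
    }
}

fn main() {
    let (backlog_tx, backlog_rx) = channel();
    let (stream_tx, stream_rx) = channel();
    let ctx = Ctx { tokens: Vec::with_capacity(256), streamer: stream_tx };
    backlog_tx.send((1, ctx)).unwrap();

    let looper = thread::spawn(move || {
        run_looper(
            backlog_rx,
            vec![
                FakeStep { request_id: 1, token_id: 42, is_final: false },
                FakeStep { request_id: 1, token_id: 43, is_final: true },
            ],
        )
    });

    // Client side: drain the stream, like UnboundedReceiverStream in schedule()
    while let Ok(text) = stream_rx.recv() {
        println!("{text}");
    }
    looper.join().unwrap();
}
```

The design point the diff leans on is visible here: a failed `send` is the only client-liveness signal the looper gets, which is why `GenerationContext` now owns its `streamer` and `tokens` directly instead of routing tokens through a second channel to a separate post-processor thread.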