diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index ca4ca024d..6d3297d66 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -4,17 +4,14 @@ pub mod errors;
 mod looper;
 mod utils;
 
-pub(crate) type RequestId = u64;
-pub(crate) type TokenId = u32;
-
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[derive(Debug, Clone)]
     pub struct GenerationStep {
-        request_id: RequestId,
-        token_id: TokenId,
+        request_id: u64,
+        token_id: u32,
         log_prob: f32,
         is_final: bool,
         has_error: bool,
@@ -53,7 +50,7 @@ mod ffi {
         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            tokens: &[TokenId],
+            tokens: &[u32],
             max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
@@ -68,4 +65,5 @@ mod ffi {
             self: Pin<&mut TensorRtLlmBackendImpl>,
         ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
     }
+
 }
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 99d75b81c..4247f338d 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -3,46 +3,34 @@ use std::ops::Deref;
 use std::path::Path;
 
 use async_trait::async_trait;
-use cxx::UniquePtr;
-use hashbrown::{HashMap, HashSet};
+use cxx::{UniquePtr};
+use hashbrown::{HashMap};
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
-use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, debug_span, error, info, info_span, span, Level};
+use tracing::{debug, error};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
-use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
-use crate::RequestId;
 
 type InferResult<T> = Result<T, InferError>;
 
 struct IdentifiableRequest<T> {
-    request_id: RequestId,
+    request_id: u64,
     inner: T,
 }
 
-macro_rules! identifiable {
-    ($id: expr, $inner: expr) => {
-        IdentifiableRequest {
-            id: $id,
-            inner: $inner,
-        }
-    };
-}
-
 /// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
 struct ValidGenerateRequestWithTokens {
     encoding: Encoding,
@@ -52,8 +40,8 @@ struct ValidGenerateRequestWithTokens {
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
     request: ValidGenerateRequestWithTokens,
-    start: Instant,
-    queued: Option<Instant>,
+    start: Option<Instant>,
+    queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
@@ -64,10 +52,10 @@ struct DecodedToken {
     is_final: bool,
 }
 
-impl TryFrom<GenerationStep> for DecodedToken {
+impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     type Error = InferError;
 
-    fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
+    fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
         if !step.has_error {
             Ok(Self {
                 id: step.token_id,
@@ -75,7 +63,7 @@ impl TryFrom<GenerationStep> for DecodedToken {
                 is_final: step.is_final,
             })
         } else {
-            Err(GenerationError(step.error_msg))
+            Err(GenerationError(step.error_msg.clone()))
         }
     }
 }
@@ -89,86 +77,84 @@ struct DecodedTokenContext {
 fn executor_status_looper(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
+    post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
 ) {
     // Track the tuple (request_id, stream) for each request
-    let mut in_flights = HashMap::<RequestId, GenerationContext>::with_capacity(128);
+    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
-    'executor: loop {
-        span!(Level::DEBUG, "[in-flight][submit]").in_scope(|| {
-            // Is there any request pending to be scheduled?
-            let awaiting_requests = waiting_requests.len();
-            for _ in 0..awaiting_requests {
-                // Retrieve all the requests
-                if let Some(mut ctx) = waiting_requests.blocking_recv() {
-                    // Submit all the request to the executor and move the context to the in-flight tracker
-                    let request = &ctx.request;
-                    let generation_params = &request.inner.parameters;
-                    let stopping_params = &request.inner.stopping_parameters;
+    'scheduler: loop {
+        // Is there any request pending to be scheduled?
+        let awaiting_requests = waiting_requests.len();
+        for _ in 0..awaiting_requests {
+            // Retrieve all the requests
+            if let Some(mut ctx) = waiting_requests.blocking_recv() {
+                // Submit all the request to the executor and move the context to the in-flight tracker
+                let request = &ctx.request;
+                let generation_params = &request.inner.parameters;
+                let stopping_params = &request.inner.stopping_parameters;
 
-                    // Submit to the TensorRT-LLM executor for scheduling
-                    match backend.pin_mut().submit(
-                        request.encoding.get_ids(),
-                        stopping_params.max_new_tokens,
-                        generation_params.top_k as i32,
-                        generation_params.top_p,
-                        generation_params.temperature,
-                        generation_params.repetition_penalty,
-                        generation_params.frequency_penalty,
-                        generation_params.seed,
-                    ) {
-                        Ok(request_id) => {
-                            // Insert the context linked to the generated request id in the tracker
-                            debug!("[in-flight] Added {}", request_id);
-                            ctx.queued = Instant::now();
-                            in_flights.insert(request_id, ctx);
+                // Submit to the TensorRT-LLM executor for scheduling
+                match backend.pin_mut().submit(
+                    request.encoding.get_ids(),
+                    stopping_params.max_new_tokens,
+                    generation_params.top_k as i32,
+                    generation_params.top_p,
+                    generation_params.temperature,
+                    generation_params.repetition_penalty,
+                    generation_params.frequency_penalty,
+                    generation_params.seed,
+                ) {
+                    Ok(request_id) => {
+                        // Insert the context linked to the generated request id in the tracker
+                        debug!("[in-flight] Added {}", request_id);
+                        ctx.start = Some(Instant::now());
+                        in_flights.insert(request_id, ctx);
+                    }
+                    Err(e) => {
+                        // Return to the caller
+                        let what = e.to_string();
+                        error!(error = what.as_str(), "Failed to schedule request");
+
+                        let err = Err(InferError::SchedulingError(what));
+                        if let Err(_) = ctx.streamer.send(err) {
+                            error!("Failed to send back error to the client");
                         }
-                        Err(e) => {
-                            // Return to the caller
-                            let what = Err(InferError::SchedulingError(e.to_string()));
-                            if let Err(ref e) = ctx.streamer.send(what) {
-                                error!("Failed to send the client", error = e.as_ref());
+                    }
+                };
+            }
+        }
+
+        if backend.num_responses_ready() > 0 {
+            match backend.pin_mut().pull_tokens() {
+                Ok(responses) => {
+                    // Iterate through all the decoded token
+                    for step in responses.deref() {
+                        if let Some(ctx) = in_flights.get(&step.request_id) {
+
+                            // Remove from tracked requests
+                            let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                                token: dt,
+                                channel: ctx.streamer.clone(),
+                            });
+
+                            // Submit the work to the post_processor
+                            let posted = post_processor_sender.send((step.request_id, parcel));
+
+                            if posted.is_err() || step.is_final {
+                                debug!("Removing {}", step.request_id);
+                                let _ = in_flights.remove(&step.request_id);
                             }
+                        } else {
+                            warn!("Untracked request {}", step.request_id,);
                         }
                     };
                 }
+                Err(ref err) => {
+                    error!("Failed to get responses from the executor: {}.", err.what());
+                    break 'scheduler;
+                }
             }
-        });
-
-        if let Err(ref e) = info_span!("[in-flight][poll]").in_scope(|| {
-            if backend.num_responses_ready() > 0 {
-                let responses = backend
-                    .pin_mut()
-                    .pull_tokens()
-                    .map_err(|e| Err(GenerationError(e.what())))?;
-
-                // Iterate through all the decoded token
-                for step in responses.deref() {
-                    if let Some(ctx) = in_flights.get(&step.request_id) {
-                        let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                            token: dt,
-                            channel: ctx.streamer.clone(),
-                        });
-
-                        // Submit the work to the post_processor
-                        let delivered = post_processor_sender.send(parcel);
-
-                        // Remove from tracked requests
-                        if step.is_final {
debug!("Removing {}", step.request_id); - let _ = in_flights.remove(&step.request_id); - } - - delivered - } else { - warn!("Untracked request {}", step.request_id,); - } - }?; - } - }) { - error!("Error in the executor's loop, exiting", error = e.as_ref()); - break 'executor; } // Hint the CPU we are spin-locking @@ -178,7 +164,7 @@ fn executor_status_looper( fn post_processor_looper( tokenizer: Tokenizer, - mut decoded_tokens: UnboundedReceiver, + mut decoded_tokens: UnboundedReceiver<(u64, InferResult)>, ) { 'post_processor: loop { if decoded_tokens.is_closed() { @@ -186,7 +172,7 @@ fn post_processor_looper( break 'post_processor; } - let mut states = HashMap::with_capacity(128); + let mut states: HashMap> = HashMap::with_capacity(128); if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() { let state = states.entry(request_id).or_insert(vec![]); @@ -194,6 +180,9 @@ fn post_processor_looper( } } + +unsafe impl Send for crate::ffi::TensorRtLlmBackendImpl {} + pub struct TensorRtLlmBackendV2 { tokenizer: Tokenizer, executor_looper: JoinHandle<()>, @@ -292,11 +281,11 @@ impl Backend for TensorRtLlmBackendV2 { let (streamer, receiver) = unbounded_channel::>(); // Send the context to the executor for scheduling - let start = Instant::now(); + let queued = Instant::now(); match self.executor.send(GenerationContext { request, - start, - queued: None, + start: None, + queued, streamer, }) { Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),