From b1846fb4e6e963a24405088604142b9ca4600b2d Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Sun, 11 Aug 2024 14:10:28 +0200
Subject: [PATCH] (backend) refactor & cleanup

---
 backends/trtllm/src/lib.rs    |   9 +-
 backends/trtllm/src/looper.rs | 240 ++++++++++++++--------------------
 2 files changed, 107 insertions(+), 142 deletions(-)

diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index edd8caff..ca4ca024 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -4,14 +4,17 @@ pub mod errors;
 mod looper;
 mod utils;
 
+pub(crate) type RequestId = u64;
+pub(crate) type TokenId = u32;
+
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[derive(Debug, Clone)]
     pub struct GenerationStep {
-        request_id: u64,
-        token_id: u32,
+        request_id: RequestId,
+        token_id: TokenId,
         log_prob: f32,
         is_final: bool,
         has_error: bool,
@@ -50,7 +53,7 @@ mod ffi {
         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            tokens: &[u32],
+            tokens: &[TokenId],
             max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index f070bad6..99d75b81 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -1,11 +1,10 @@
 use std::hint;
 use std::ops::Deref;
 use std::path::Path;
-use std::sync::OnceLock;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use hashbrown::HashMap;
+use hashbrown::{HashMap, HashSet};
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::error::SendError;
@@ -13,7 +12,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info, span, Level};
+use tracing::{debug, debug_span, error, info, info_span, span, Level};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -26,32 +25,74 @@ use text_generation_router::{FinishReason, Token};
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
-
-// Value used to poll the state of the generation stream
-static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
-
-// It's safe to send the backend between threads
-unsafe impl Send for TensorRtLlmBackendImpl {}
+use crate::RequestId;
 
 type InferResult<T> = Result<T, InferError>;
 
+struct IdentifiableRequest<T> {
+    request_id: RequestId,
+    inner: T,
+}
+
+macro_rules! identifiable {
+    ($id: expr, $inner: expr) => {
+        IdentifiableRequest {
+            request_id: $id,
+            inner: $inner,
+        }
+    };
+}
+
+/// Wraps the ValidGenerateRequest forwarded by the TGI server together with the tokenized view of the prompt
 struct ValidGenerateRequestWithTokens {
     encoding: Encoding,
     inner: ValidGenerateRequest,
 }
 
-struct DecodedTokenContext {
-    tokens: Vec<GenerationStep>,
-    ctx: UnboundedSender<InferResult<InferStreamResponse>>,
+/// Wraps the request together with the channel used to stream the decoded tokens back to the client
+struct GenerationContext {
+    request: ValidGenerateRequestWithTokens,
+    start: Instant,
+    queued: Option<Instant>,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
-fn executor_status_poller(
+#[derive(Debug, Copy, Clone)]
+struct DecodedToken {
+    id: u32,
+    log_prob: f32,
+    is_final: bool,
+}
+
+impl TryFrom<GenerationStep> for DecodedToken {
+    type Error = InferError;
+
+    fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
+        if !step.has_error {
+            Ok(Self {
+                id: step.token_id,
+                log_prob: step.log_prob,
+                is_final: step.is_final,
+            })
+        } else {
+            Err(GenerationError(step.error_msg))
+        }
+    }
+}
+
+/// Wraps a decoded token together with the channel used to stream it back to the client
+struct DecodedTokenContext {
+    token: DecodedToken,
+    channel: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
+fn executor_status_looper(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
+    mut post_processor_sender: UnboundedSender<(RequestId, InferResult<DecodedTokenContext>)>,
 ) {
     // Track the tuple (request_id, stream) for each request
-    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
+    let mut in_flights = HashMap::<RequestId, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
     'executor: loop {
@@ -60,7 +101,7 @@ fn executor_status_poller(
             let awaiting_requests = waiting_requests.len();
             for _ in 0..awaiting_requests {
                 // Retrieve all the requests
-                if let Some(ctx) = waiting_requests.blocking_recv() {
+                if let Some(mut ctx) = waiting_requests.blocking_recv() {
                     // Submit all the request to the executor and move the context to the in-flight tracker
                     let request = &ctx.request;
                     let generation_params = &request.inner.parameters;
@@ -79,13 +120,15 @@ fn executor_status_poller(
                 ) {
                     Ok(request_id) => {
                         // Insert the context linked to the generated request id in the tracker
+                        debug!("[in-flight] Added {}", request_id);
+                        ctx.queued = Some(Instant::now());
                         in_flights.insert(request_id, ctx);
                     }
                     Err(e) => {
                         // Return to the caller
                         let what = Err(InferError::SchedulingError(e.to_string()));
-                        if let Err(e) = ctx.streamer.send(what) {
-                            error!("Failed to send back through the channel: {}", e);
+                        if let Err(ref e) = ctx.streamer.send(what) {
+                            error!(error = %e, "Failed to send the error back to the client");
                         }
                     }
                 };
@@ -93,83 +136,38 @@ fn executor_status_poller(
             }
         });
 
-        if let Err(e) = span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
+        if let Err(ref e) = info_span!("[in-flight][poll]").in_scope(|| {
             if backend.num_responses_ready() > 0 {
-                match backend.pin_mut().pull_tokens() {
-                    Ok(responses) => {
-                        debug!("Received {} tokens from the executor", responses.len());
-
-                        // worse case scenario is one token for each response: with_capacity(responses.len())
-                        // grouper will group decoded tokens per request to decode multiple tokens
-                        let mut grouper: HashMap<u64, DecodedTokenContext> =
-                            HashMap::with_capacity(responses.len());
-
-                        // Iterate through all the decoded token
-                        for step in responses.deref() {
-                            match in_flights.get(&step.request_id) {
-                                Some(ctx) => {
-                                    debug!(
-                                        "{} -> (token={}, final={})",
-                                        step.request_id, step.token_id, step.is_final
-                                    );
-
-                                    // If no error, let's forward to post-processor
-                                    if !step.has_error {
-                                        let req_group = grouper.entry(step.request_id).or_insert(
-                                            DecodedTokenContext {
-                                                tokens: vec![],
-                                                ctx: ctx.streamer.clone(), // Arc::clone() = cheap
-                                            },
-                                        );
-                                        req_group.tokens.push(step.clone()); // Should be ultra cheap
-                                    } else {
-                                        warn!(
-                                            "Error for request: {} -> {}",
-                                            step.request_id, &step.error_msg
-                                        );
-
-                                        // TODO: Send something back to the postprocessor for the client?
-                                    }
-
-                                    // Remove from tracked requests
-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
-                                    }
-                                }
-                                None => {
-                                    if step.has_error {
-                                        error!(
-                                            "Untracked request {} -> {}",
-                                            step.request_id, &step.error_msg
-                                        );
-                                        continue;
-                                    } else {
-                                        error!(
-                                            "Got step for untracked request {}",
-                                            step.request_id
-                                        );
-                                    }
-                                }
-                            }
-                        }
-
-                        grouper
-                            .into_values()
-                            .map(|ctx| post_processor_sender.send(ctx))
-                            .collect::<Result<(), SendError<DecodedTokenContext>>>()?;
-                    }
-                    Err(err) => {
-                        error!("Failed to retrieve tokens from the executor: {}", err);
-                    }
-                }
+                let responses = backend
+                    .pin_mut()
+                    .pull_tokens()
+                    .map_err(|e| GenerationError(e.what().to_owned()))?;
+
+                // Iterate through all the decoded tokens
+                for step in responses.deref() {
+                    if let Some(ctx) = in_flights.get(&step.request_id) {
+                        let parcel =
+                            DecodedToken::try_from(step.clone()).map(|dt| DecodedTokenContext {
+                                token: dt,
+                                channel: ctx.streamer.clone(),
+                            });
+
+                        // Submit the work to the post_processor
+                        let delivered = post_processor_sender
+                            .send((step.request_id, parcel))
+                            .map_err(|e| GenerationError(e.to_string()));
+
+                        // Remove from tracked requests
+                        if step.is_final {
+                            debug!("Removing {}", step.request_id);
+                            let _ = in_flights.remove(&step.request_id);
+                        }
+
+                        delivered?;
+                    } else {
+                        warn!("Untracked request {}", step.request_id);
+                    }
+                }
             }
 
-            Ok::<(), SendError<DecodedTokenContext>>(())
+            Ok::<(), InferError>(())
         }) {
-            error!(
-                "Caught an fatal error in the executor's loop, about to exit. {}",
-                e
-            );
+            error!(error = %e, "Error in the executor's loop, exiting");
             break 'executor;
         }
@@ -180,7 +178,7 @@
 fn post_processor_looper(
     tokenizer: Tokenizer,
-    mut decoded_tokens: UnboundedReceiver<DecodedTokenContext>,
+    mut decoded_tokens: UnboundedReceiver<(RequestId, InferResult<DecodedTokenContext>)>,
 ) {
     'post_processor: loop {
         if decoded_tokens.is_closed() {
@@ -188,56 +186,14 @@ fn post_processor_looper(
             break 'post_processor;
         }
 
-        if let Some(ctx) = decoded_tokens.blocking_recv() {
-            ctx.tokens.iter().for_each(|step| {
-                let out = match tokenizer.decode(&[step.token_id], true) {
-                    Ok(text) => {
-                        let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
-                        let token = Token {
-                            id: step.token_id,
-                            text,
-                            logprob: step.log_prob,
-                            special: is_special,
-                        };
-
-                        let response = if !step.is_final {
-                            InferStreamResponse::Intermediate {
-                                token,
-                                top_tokens: vec![],
-                            }
-                        } else {
-                            InferStreamResponse::End {
-                                token,
-                                top_tokens: vec![],
-                                generated_text: GeneratedText {
-                                    text: String::from(""),
-                                    generated_tokens: 0,
-                                    finish_reason: FinishReason::Length,
-                                    seed: None,
-                                },
-                                start: Instant::now(), // Handle start time
-                                queued: Instant::now(), // Handle queued time
-                            }
-                        };
-
-                        Ok(response)
-                    }
-                    Err(e) => Err(GenerationError(e.to_string())),
-                };
-
-                if let Err(e) = ctx.ctx.send(out) {
-                    warn!("Failed to send back the decoded tokens: {}", e);
-                };
-            });
+        let mut states: HashMap<RequestId, Vec<DecodedToken>> = HashMap::with_capacity(128);
+
+        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
+            let state = states.entry(request_id).or_insert(vec![]);
         }
     }
 }
 
-struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
-}
-
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
@@ -277,7 +233,7 @@ impl TensorRtLlmBackendV2 {
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         let executor_looper = spawn_blocking(move || {
-            executor_status_poller(backend, executor_receiver, post_processor_sender)
+            executor_status_looper(backend, executor_receiver, post_processor_sender)
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
@@ -295,22 +251,22 @@
     fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
         if request.top_n_tokens > 1 {
-            return Err(InferError::ValidationError(TopNTokensDisabled));
+            return Err(ValidationError(TopNTokensDisabled));
         }
 
         // TODO: Is it really needed? How can it be validated before?
         if request.parameters.grammar.is_some() {
-            return Err(InferError::ValidationError(Grammar));
+            return Err(ValidationError(Grammar));
        }
 
         match request.inputs.len() {
-            0 => Err(InferError::ValidationError(EmptyInput)),
-            2.. => Err(InferError::GenerationError(
+            0 => Err(ValidationError(EmptyInput)),
+            2.. => Err(GenerationError(
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
                 Chunk::Text(text) => Ok(text),
-                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+                Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
     }
@@ -336,7 +292,13 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
-        match self.executor.send(GenerationContext { request, streamer }) {
+        let start = Instant::now();
+        match self.executor.send(GenerationContext {
+            request,
+            start,
+            queued: None,
+            streamer,
+        }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
                 "Failed to submit request to the backend".into(),
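
Not part of the patch: the refactor above reduces the decode path to "convert each GenerationStep into a DecodedToken (or an error), tag it with its request id, hand it to the post-processor over a channel, and accumulate per-request state there". The sketch below reproduces that flow in isolation so it can be compiled and run on its own. It is illustrative only: it uses a std::sync::mpsc channel in place of the tokio unbounded channel from the patch, and defines local stand-ins for the cxx-generated GenerationStep and the TGI error types.

    use std::collections::HashMap;
    use std::convert::TryFrom;
    use std::sync::mpsc::channel;

    type RequestId = u64;

    // Local stand-in for the cxx-generated `GenerationStep` shared struct.
    #[derive(Debug, Clone)]
    struct GenerationStep {
        request_id: RequestId,
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        has_error: bool,
        error_msg: String,
    }

    #[derive(Debug, Copy, Clone)]
    struct DecodedToken {
        id: u32,
        log_prob: f32,
        is_final: bool,
    }

    impl TryFrom<GenerationStep> for DecodedToken {
        type Error = String;

        // A step either carries a token or an error message, never both.
        fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
            if !step.has_error {
                Ok(Self {
                    id: step.token_id,
                    log_prob: step.log_prob,
                    is_final: step.is_final,
                })
            } else {
                Err(step.error_msg)
            }
        }
    }

    fn main() {
        // Executor side: tag every decoded step with its request id before forwarding it.
        let (post_processor_sender, post_processor_receiver) =
            channel::<(RequestId, Result<DecodedToken, String>)>();

        let steps = vec![
            GenerationStep { request_id: 1, token_id: 42, log_prob: -0.1, is_final: false, has_error: false, error_msg: String::new() },
            GenerationStep { request_id: 1, token_id: 7, log_prob: -0.3, is_final: true, has_error: false, error_msg: String::new() },
        ];
        for step in steps {
            let request_id = step.request_id;
            let _ = post_processor_sender.send((request_id, DecodedToken::try_from(step)));
        }
        drop(post_processor_sender);

        // Post-processor side: accumulate tokens per request id, which is what the new
        // `states` map in `post_processor_looper` appears intended for.
        let mut states: HashMap<RequestId, Vec<DecodedToken>> = HashMap::new();
        while let Ok((request_id, decoded)) = post_processor_receiver.recv() {
            match decoded {
                Ok(token) => states.entry(request_id).or_insert_with(Vec::new).push(token),
                Err(msg) => eprintln!("request {request_id} failed: {msg}"),
            }
        }
        println!("{states:?}");
    }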