(backend) implement the post_processor background thread

Author: Morgan Funtowicz, 2024-08-05 13:27:18 +00:00 (committed by Morgan Funtowicz)
Parent: 0dca168bcb
Commit: c2e21d8725
2 changed files with 129 additions and 19 deletions

File 1 of 2:

@@ -6,9 +6,9 @@ mod utils;
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
+    #[derive(Debug, Clone)]
     pub struct GenerationStep {
         request_id: u64,
         token_id: u32,
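Deriving `Clone` (and `Debug`) on `GenerationStep` is what lets the executor loop in the second file buffer copies of each step into per-request groups before handing them to the new post-processor thread. Below is a minimal sketch of that idea using a plain Rust stand-in for the cxx-shared struct; only `request_id` and `token_id` appear in this hunk, the remaining fields and their types are inferred from how the looper reads the struct later in this commit and should be treated as assumptions.

```rust
// Stand-in for the cxx shared struct; field set beyond `request_id` and
// `token_id` mirrors the accesses made in the looper below (types assumed).
#[derive(Debug, Clone)]
pub struct GenerationStep {
    pub request_id: u64,
    pub token_id: u32,
    pub log_prob: f32,
    pub is_final: bool,
    pub has_error: bool,
    pub error_msg: String,
}

fn main() {
    let step = GenerationStep {
        request_id: 1,
        token_id: 42,
        log_prob: -0.25,
        is_final: false,
        has_error: false,
        error_msg: String::new(),
    };

    // `Clone` lets the poller push a copy of every step into a per-request
    // buffer while the original batch is still borrowed from the FFI layer.
    let mut buffered: Vec<GenerationStep> = Vec::new();
    buffered.push(step.clone());
    println!("{buffered:?}");
}
```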
File 2 of 2:

@@ -6,19 +6,22 @@ use std::sync::OnceLock;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
+use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::task::JoinHandle;
+use tokio::task::{JoinHandle, spawn_blocking};
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{error, info, Level, span};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::GenerationError;
 use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;

 // Value used to poll the state of the generation stream
@@ -34,15 +37,21 @@ struct ValidGenerateRequestWithTokens {
     inner: ValidGenerateRequest,
 }

+struct DecodedTokenContext {
+    tokens: Vec<GenerationStep>,
+    ctx: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
 fn executor_status_poller(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
+    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);

     // TODO: Does it need a spin-loop?
-    loop {
+    'executor: loop {
         span!(Level::DEBUG, "[in-flight][submit]").in_scope(|| {
             // Is there any request pending to be scheduled?
             let awaiting_requests = waiting_requests.len();
@@ -84,18 +93,40 @@ fn executor_status_poller(
             }
         });

-        span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
+        if let Err(e) = span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
             if backend.num_responses_ready() > 0 {
                 match backend.pin_mut().pull_tokens() {
                     Ok(responses) => {
+                        // worst case scenario is one token for each response: with_capacity(responses.len())
+                        // grouper will group decoded tokens per request to decode multiple tokens
+                        let mut grouper: HashMap<u64, DecodedTokenContext> =
+                            HashMap::with_capacity(responses.len());
+
+                        // Iterate through all the decoded tokens
                         for step in responses.deref() {
                             let request_id = step.request_id;

                             match in_flights.get(&request_id) {
                                 Some(ctx) => {
                                     info!("New token for {} -> {}", request_id, step.token_id);

-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
+                                    if !step.has_error {
+                                        let req_group = grouper.entry_ref(&request_id).or_insert(
+                                            DecodedTokenContext {
+                                                tokens: vec![],
+                                                ctx: ctx.streamer.clone(), // Arc::clone() = cheap
+                                            },
+                                        );
+                                        req_group.tokens.push(step.clone()); // Should be ultra cheap
+
+                                        if step.is_final {
+                                            let _ = in_flights.remove(&step.request_id);
+                                        }
+                                    } else {
+                                        warn!(
+                                            "Error for request: {} -> {}",
+                                            request_id, &step.error_msg
+                                        );
                                     }
                                 }
                                 None => {
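The new grouping logic above batches every step belonging to the same request into one `DecodedTokenContext`, so the post-processor receives a single message per request per poll. Here is a self-contained sketch of the same idiom; it uses the standard library `entry` API on the copyable `u64` key instead of hashbrown's `entry_ref`, and `Step`/`Group` are simplified stand-ins for the real types.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone)]
struct Step {
    request_id: u64,
    token_id: u32,
}

#[derive(Debug, Default)]
struct Group {
    tokens: Vec<Step>,
}

fn main() {
    let responses = vec![
        Step { request_id: 1, token_id: 10 },
        Step { request_id: 2, token_id: 20 },
        Step { request_id: 1, token_id: 11 },
    ];

    // One entry per request, each accumulating every step pulled in this poll.
    let mut grouper: HashMap<u64, Group> = HashMap::with_capacity(responses.len());
    for step in &responses {
        grouper
            .entry(step.request_id)
            .or_default()
            .tokens
            .push(step.clone());
    }

    // Each group would then be sent as a single message to the post-processor.
    for (request_id, group) in &grouper {
        println!("request {request_id}: {} token(s)", group.tokens.len());
    }
}
```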
@@ -103,19 +134,87 @@ fn executor_status_poller(
                                 }
                             }
                         }
+
+                        grouper
+                            .into_values()
+                            .map(|ctx| post_processor_sender.send(ctx))
+                            .collect()?;
                     }
                     Err(err) => {
                         error!("Failed to retrieve tokens from the executor: {}", err);
                     }
                 }
             }
-        });
+
+            Ok(())
+        }) {
+            error!(
+                "Caught a fatal error in the executor's loop, about to exit. {}",
+                e
+            );
+            break 'executor;
+        }

         // Hint the CPU we are spin-locking
         hint::spin_loop();
     }
 }

+fn post_processor_looper(
+    tokenizer: Tokenizer,
+    mut decoded_tokens: UnboundedReceiver<DecodedTokenContext>,
+) {
+    'post_processor: loop {
+        if decoded_tokens.is_closed() {
+            warn!("Post processor IPC is closed, loop will exit now.");
+            break 'post_processor;
+        }
+
+        if let Some(ctx) = decoded_tokens.blocking_recv() {
+            ctx.tokens.iter().for_each(|step| {
+                let out = match tokenizer.decode(&[step.token_id], true) {
+                    Ok(text) => {
+                        let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+                        let token = Token {
+                            id: step.token_id,
+                            text,
+                            logprob: step.log_prob,
+                            special: is_special,
+                        };
+
+                        let response = if !step.is_final {
+                            InferStreamResponse::Intermediate {
+                                token,
+                                top_tokens: vec![],
+                            }
+                        } else {
+                            InferStreamResponse::End {
+                                token,
+                                top_tokens: vec![],
+                                generated_text: GeneratedText {
+                                    text: String::from(""),
+                                    generated_tokens: 0,
+                                    finish_reason: FinishReason::Length,
+                                    seed: None,
+                                },
+                                start: Instant::now(),  // Handle start time
+                                queued: Instant::now(), // Handle queued time
+                            }
+                        };
+
+                        Ok(response)
+                    }
+                    Err(e) => Err(GenerationError(e.to_string())),
+                };
+
+                if let Err(e) = ctx.ctx.send(out) {
+                    warn!("Failed to send back the decoded tokens: {}", e);
+                };
+            });
+        }
+    }
+}
+
 struct GenerationContext {
     request: ValidGenerateRequestWithTokens,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
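The post-processor decodes each step individually with `Tokenizer::decode(&[id], true)` and flags specials via `get_added_vocabulary().is_special_token(...)`. Below is a standalone sketch of those two `tokenizers` calls; the `tokenizer.json` path and the sample token id are placeholders, not values used by this commit.

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path: any serialized Hugging Face tokenizer file works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Decode a single token id, as the post-processor loop does per step.
    let token_id: u32 = 0;
    let text = tokenizer.decode(&[token_id], true)?;
    let special = tokenizer.get_added_vocabulary().is_special_token(&text);

    println!("token {token_id} -> {text:?} (special: {special})");
    Ok(())
}
```

The second argument to `decode` is `skip_special_tokens`, mirroring the `true` passed in the loop above.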
@@ -123,8 +222,9 @@ struct GenerationContext {
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
-    looper: JoinHandle<()>,
-    queue: UnboundedSender<GenerationContext>,
+    executor_looper: JoinHandle<()>,
+    post_processor_looper: JoinHandle<()>,
+    executor: UnboundedSender<GenerationContext>,
 }

 impl TensorRtLlmBackendV2 {
@@ -150,20 +250,28 @@ impl TensorRtLlmBackendV2 {
         );

         // Allocate the IPC layer to communicate with the backend
-        let (requests_sender, requests_receiver) = unbounded_channel::<GenerationContext>();
+        let (executor_sender, executor_receiver) = unbounded_channel();
+        let (post_processor_sender, post_processor_receiver) = unbounded_channel();

         // Create the FFI backend
         let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

-        // Looper is responsible for scheduling and pulling requests state at regular interval
-        let looper =
-            tokio::task::spawn_blocking(move || executor_status_poller(backend, requests_receiver));
+        // Executor looper is responsible for scheduling and pulling requests state at regular interval
+        let executor_looper = spawn_blocking(move || {
+            executor_status_poller(backend, executor_receiver, post_processor_sender)
+        });
+
+        // Post processor looper is responsible for receiving a bunch of tokens, decoding them and sending them back to the user
+        let tokenizer_ = tokenizer.clone();
+        let post_processor_looper =
+            spawn_blocking(move || post_processor_looper(tokenizer_, post_processor_receiver));

         Ok(TensorRtLlmBackendV2 {
             tokenizer,
-            looper,
-            queue: requests_sender,
+            executor_looper,
+            post_processor_looper,
+            executor: executor_sender,
         })
     }
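The constructor now wires two blocking loops together through an unbounded channel: the executor produces `DecodedTokenContext` batches, and the post-processor drains them from a blocking thread. Here is a stripped-down sketch of that producer/consumer wiring, with `Work` and `Decoded` as toy placeholders for the real context types.

```rust
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::spawn_blocking;

// Simplified stand-ins for GenerationContext / DecodedTokenContext.
struct Work(u64);
struct Decoded(u64);

fn executor(mut requests: UnboundedReceiver<Work>, decoded: UnboundedSender<Decoded>) {
    // Blocking producer: drain incoming requests and forward "decoded" results.
    while let Some(Work(id)) = requests.blocking_recv() {
        if decoded.send(Decoded(id)).is_err() {
            break; // post-processor is gone, stop the loop
        }
    }
}

fn post_processor(mut decoded: UnboundedReceiver<Decoded>) {
    // Blocking consumer: park on the channel instead of spin-looping.
    while let Some(Decoded(id)) = decoded.blocking_recv() {
        println!("post-processing request {id}");
    }
}

#[tokio::main]
async fn main() {
    let (req_tx, req_rx) = unbounded_channel();
    let (dec_tx, dec_rx) = unbounded_channel();

    let exec = spawn_blocking(move || executor(req_rx, dec_tx));
    let post = spawn_blocking(move || post_processor(dec_rx));

    for id in 0..3 {
        req_tx.send(Work(id)).unwrap();
    }
    drop(req_tx); // closing the request channel lets both loops terminate

    exec.await.unwrap();
    post.await.unwrap();
}
```

Dropping the request sender lets the producer loop end, which in turn drops the decoded-token sender and lets the consumer exit; the real `post_processor_looper` additionally checks `is_closed()` before blocking.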
@@ -212,7 +320,7 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();

         // Send the context to the executor for scheduling
-        match self.queue.send(GenerationContext { request, streamer }) {
+        match self.executor.send(GenerationContext { request, streamer }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
                 "Failed to submit request to the backend".into(),
@@ -221,6 +329,8 @@ impl Backend for TensorRtLlmBackendV2 {
     }

     async fn health(&self, current_health: bool) -> bool {
-        current_health & !self.looper.is_finished()
+        current_health
+            & !self.executor_looper.is_finished()
+            & !self.post_processor_looper.is_finished()
     }
 }
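`health` now reports unhealthy as soon as either background thread stops, using the non-blocking `JoinHandle::is_finished()`. A toy sketch of that check follows; the stand-in tasks and timings are made up purely for illustration.

```rust
use std::time::Duration;
use tokio::task::spawn_blocking;
use tokio::time::sleep;

#[tokio::main]
async fn main() {
    // Stand-ins for the two loopers: one keeps running, one exits immediately.
    let executor_looper = spawn_blocking(|| std::thread::sleep(Duration::from_millis(500)));
    let post_processor_looper = spawn_blocking(|| ());

    // Give the runtime a moment to run the short-lived task to completion.
    sleep(Duration::from_millis(100)).await;

    let current_health = true;
    // Same shape as `health()` above: healthy only while both threads are alive.
    let healthy = current_health
        & !executor_looper.is_finished()
        & !post_processor_looper.is_finished();

    println!("healthy = {healthy}"); // false: the post-processor stand-in already finished
}
```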