(backend) implement the post_processor background thread
Parent: 0dca168bcb
Commit: c2e21d8725
@@ -6,9 +6,9 @@ mod utils;
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
 
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
-
+    #[derive(Debug, Clone)]
     pub struct GenerationStep {
         request_id: u64,
         token_id: u32,
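Only the first two fields of the bridge struct are visible in this hunk; the poller and post-processor further down also read log_prob, is_final, has_error and error_msg from it. A minimal plain-Rust sketch of the shape those loops rely on (everything past request_id and token_id is inferred from how the fields are used later in this diff, not copied from the bridge):

/// Sketch only: a plain-Rust stand-in for the cxx-bridge GenerationStep.
/// Field types past `request_id`/`token_id` are assumptions inferred from usage.
#[derive(Debug, Clone)]
pub struct GenerationStep {
    pub request_id: u64,
    pub token_id: u32,
    pub log_prob: f32,
    pub is_final: bool,
    pub has_error: bool,
    pub error_msg: String,
}

fn main() {
    let step = GenerationStep {
        request_id: 1,
        token_id: 42,
        log_prob: -0.25,
        is_final: false,
        has_error: false,
        error_msg: String::new(),
    };
    println!("{step:?}");
}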
@@ -6,19 +6,22 @@ use std::sync::OnceLock;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
+use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::task::JoinHandle;
+use tokio::task::{JoinHandle, spawn_blocking};
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{error, info, Level, span};
 
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::GenerationError;
 use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
 
 // Value used to poll the state of the generation stream
@@ -34,15 +37,21 @@ struct ValidGenerateRequestWithTokens {
     inner: ValidGenerateRequest,
 }
 
+struct DecodedTokenContext {
+    tokens: Vec<GenerationStep>,
+    ctx: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
 fn executor_status_poller(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
+    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
-    loop {
+    'executor: loop {
         span!(Level::DEBUG, "[in-flight][submit]").in_scope(|| {
             // Is there any request pending to be scheduled?
             let awaiting_requests = waiting_requests.len();
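The unlabeled loop becomes 'executor: loop. That label is what the fatal-error branch added further down uses (break 'executor) to leave the whole polling loop from inside a nested block. A small self-contained sketch of that control-flow pattern, with a made-up error value:

fn main() {
    let mut polls = 0u32;

    'executor: loop {
        polls += 1;

        // Stand-in for the fallible poll body; the real loop polls the FFI backend.
        let outcome: Result<(), String> = if polls < 3 {
            Ok(())
        } else {
            Err(String::from("executor shut down"))
        };

        if let Err(e) = outcome {
            eprintln!("Caught a fatal error in the executor's loop, about to exit. {e}");
            break 'executor;
        }
    }

    assert_eq!(polls, 3);
}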
@@ -84,18 +93,40 @@ fn executor_status_poller(
             }
         });
 
-        span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
+        if let Err(e) = span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
             if backend.num_responses_ready() > 0 {
                 match backend.pin_mut().pull_tokens() {
                     Ok(responses) => {
+                        // worse case scenario is one token for each response: with_capacity(responses.len())
+                        // grouper will group decoded tokens per request to decode multiple tokens
+                        let mut grouper: HashMap<u64, DecodedTokenContext> =
+                            HashMap::with_capacity(responses.len());
+
+                        // Iterate through all the decoded token
                         for step in responses.deref() {
                             let request_id = step.request_id;
 
                             match in_flights.get(&request_id) {
                                 Some(ctx) => {
                                     info!("New token for {} -> {}", request_id, step.token_id);
-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
+
+                                    if !step.has_error {
+                                        let req_group = grouper.entry_ref(&request_id).or_insert(
+                                            DecodedTokenContext {
+                                                tokens: vec![],
+                                                ctx: ctx.streamer.clone(), // Arc::clone() = cheap
+                                            },
+                                        );
+                                        req_group.tokens.push(step.clone()); // Should be ultra cheap
+
+                                        if step.is_final {
+                                            let _ = in_flights.remove(&step.request_id);
+                                        }
+                                    } else {
+                                        warn!(
+                                            "Error for request: {} -> {}",
+                                            request_id, &step.error_msg
+                                        );
                                     }
                                 }
                                 None => {
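The poller now groups the pulled steps per request before forwarding them, so a request that produced several tokens in one pull crosses the channel once instead of once per token. The diff uses hashbrown's entry_ref; an equivalent sketch with the standard-library entry API and simplified stand-in types:

use std::collections::HashMap;

// Simplified stand-ins for GenerationStep / DecodedTokenContext.
#[derive(Debug, Clone)]
struct Step {
    request_id: u64,
    token_id: u32,
}

/// Group decoded steps per request id, as the poller does before handing each
/// group to the post-processor channel.
fn group_by_request(steps: &[Step]) -> HashMap<u64, Vec<Step>> {
    let mut grouper: HashMap<u64, Vec<Step>> = HashMap::with_capacity(steps.len());
    for step in steps {
        grouper
            .entry(step.request_id)
            .or_insert_with(Vec::new)
            .push(step.clone());
    }
    grouper
}

fn main() {
    let steps = vec![
        Step { request_id: 1, token_id: 10 },
        Step { request_id: 2, token_id: 11 },
        Step { request_id: 1, token_id: 12 },
    ];
    let grouped = group_by_request(&steps);
    assert_eq!(grouped[&1].len(), 2);
    assert_eq!(grouped[&2].len(), 1);
}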
@@ -103,19 +134,87 @@ fn executor_status_poller(
                                 }
                             }
                         }
+
+                        grouper
+                            .into_values()
+                            .map(|ctx| post_processor_sender.send(ctx))
+                            .collect()?;
                     }
                     Err(err) => {
                         error!("Failed to retrieve tokens from the executor: {}", err);
                     }
                 }
             }
-        });
+
+            Ok(())
+        }) {
+            error!(
+                "Caught an fatal error in the executor's loop, about to exit. {}",
+                e
+            );
+            break 'executor;
+        }
 
         // Hint the CPU we are spin-locking
        hint::spin_loop();
     }
 }
 
+fn post_processor_looper(
+    tokenizer: Tokenizer,
+    mut decoded_tokens: UnboundedReceiver<DecodedTokenContext>,
+) {
+    'post_processor: loop {
+        if decoded_tokens.is_closed() {
+            warn!("Post processor IPC is closed, loop will exit now.");
+            break 'post_processor;
+        }
+
+        if let Some(ctx) = decoded_tokens.blocking_recv() {
+            ctx.tokens.iter().for_each(|step| {
+                let out = match tokenizer.decode(&[step.token_id], true) {
+                    Ok(text) => {
+                        let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
+                        let token = Token {
+                            id: step.token_id,
+                            text,
+                            logprob: step.log_prob,
+                            special: is_special,
+                        };
+
+                        let response = if !step.is_final {
+                            InferStreamResponse::Intermediate {
+                                token,
+                                top_tokens: vec![],
+                            }
+                        } else {
+                            InferStreamResponse::End {
+                                token,
+                                top_tokens: vec![],
+                                generated_text: GeneratedText {
+                                    text: String::from(""),
+                                    generated_tokens: 0,
+                                    finish_reason: FinishReason::Length,
+                                    seed: None,
+                                },
+                                start: Instant::now(),  // Handle start time
+                                queued: Instant::now(), // Handle queued time
+                            }
+                        };
+
+                        Ok(response)
+                    }
+                    Err(e) => Err(GenerationError(e.to_string())),
+                };
+
+                if let Err(e) = ctx.ctx.send(out) {
+                    warn!("Failed to send back the decoded tokens: {}", e);
+                };
+            });
+        }
+    }
+}
+
 struct GenerationContext {
     request: ValidGenerateRequestWithTokens,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
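The post-processor's core call is tokenizer.decode on a single token id, followed by a special-token check, assuming the slice-taking decode signature this diff itself uses. A self-contained sketch with a tiny made-up word-level vocabulary so it runs without model files (the real looper receives the model's own tokenizer and builds Token / InferStreamResponse values instead of asserting):

use std::collections::HashMap;

use tokenizers::models::wordlevel::WordLevel;
use tokenizers::Tokenizer;

fn main() {
    // Made-up vocabulary so the sketch runs without any model files.
    let vocab: HashMap<String, u32> = HashMap::from([
        ("hello".to_string(), 0),
        ("world".to_string(), 1),
        ("<unk>".to_string(), 2),
    ]);
    let model = WordLevel::builder()
        .vocab(vocab)
        .unk_token("<unk>".to_string())
        .build()
        .expect("valid word-level model");
    let tokenizer = Tokenizer::new(model);

    // Decode a single token id, skipping special tokens, as the post-processor
    // does for each GenerationStep it receives.
    let text = tokenizer.decode(&[1], true).expect("decodable id");
    assert_eq!(text, "world");

    // The looper also checks whether the decoded piece is a special token before
    // building the streamed Token.
    let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
    assert!(!is_special);
}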
@@ -123,8 +222,9 @@ struct GenerationContext {
 
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
-    looper: JoinHandle<()>,
-    queue: UnboundedSender<GenerationContext>,
+    executor_looper: JoinHandle<()>,
+    post_processor_looper: JoinHandle<()>,
+    executor: UnboundedSender<GenerationContext>,
 }
 
 impl TensorRtLlmBackendV2 {
@@ -150,20 +250,28 @@ impl TensorRtLlmBackendV2 {
         );
 
         // Allocate the IPC layer to communicate with the backend
-        let (requests_sender, requests_receiver) = unbounded_channel::<GenerationContext>();
+        let (executor_sender, executor_receiver) = unbounded_channel();
+        let (post_processor_sender, post_processor_receiver) = unbounded_channel();
 
         // Create the FFI backend
         let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
            .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
 
-        // Looper is responsible for scheduling and pulling requests state at regular interval
-        let looper =
-            tokio::task::spawn_blocking(move || executor_status_poller(backend, requests_receiver));
+        // Executor looper is responsible for scheduling and pulling requests state at regular interval
+        let executor_looper = spawn_blocking(move || {
+            executor_status_poller(backend, executor_receiver, post_processor_sender)
+        });
+
+        // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
+        let tokenizer_ = tokenizer.clone();
+        let post_processor_looper =
+            spawn_blocking(move || post_processor_looper(tokenizer_, post_processor_receiver));
 
         Ok(TensorRtLlmBackendV2 {
             tokenizer,
-            looper,
-            queue: requests_sender,
+            executor_looper,
+            post_processor_looper,
+            executor: executor_sender,
         })
     }
 
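Both loopers block (spin-loop polling on one side, blocking_recv on the other), so they are moved onto Tokio's blocking thread pool with spawn_blocking instead of running as async tasks. A minimal sketch of the same two-channel wiring, with simplified String payloads standing in for GenerationContext and DecodedTokenContext:

use tokio::sync::mpsc::unbounded_channel;
use tokio::task::spawn_blocking;

#[tokio::main]
async fn main() {
    // Request channel: router -> executor looper.
    let (executor_sender, mut executor_receiver) = unbounded_channel::<String>();
    // Decoded-token channel: executor looper -> post-processor looper.
    let (post_processor_sender, mut post_processor_receiver) = unbounded_channel::<String>();

    // Executor looper: would poll the FFI backend; here it just forwards requests.
    let executor_looper = spawn_blocking(move || {
        while let Some(request) = executor_receiver.blocking_recv() {
            let _ = post_processor_sender.send(format!("decoded tokens for {request}"));
        }
    });

    // Post-processor looper: would detokenize and stream back to the client.
    let post_processor_looper = spawn_blocking(move || {
        while let Some(batch) = post_processor_receiver.blocking_recv() {
            println!("{batch}");
        }
    });

    executor_sender.send("request-1".to_string()).unwrap();
    drop(executor_sender); // closing the channel lets both loops wind down

    let _ = executor_looper.await;
    let _ = post_processor_looper.await;
}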
@@ -212,7 +320,7 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
-        match self.queue.send(GenerationContext { request, streamer }) {
+        match self.executor.send(GenerationContext { request, streamer }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
                 "Failed to submit request to the backend".into(),
@@ -221,6 +329,8 @@ impl Backend for TensorRtLlmBackendV2 {
     }
 
     async fn health(&self, current_health: bool) -> bool {
-        current_health & !self.looper.is_finished()
+        current_health
+            & !self.executor_looper.is_finished()
+            & !self.post_processor_looper.is_finished()
     }
 }