(backend) refactor & cleanup

Morgan Funtowicz 2024-08-11 14:10:28 +02:00
parent 483f172938
commit b1846fb4e6
2 changed files with 107 additions and 142 deletions

Changed file 1 of 2:

@@ -4,14 +4,17 @@ pub mod errors;
 mod looper;
 mod utils;
+
+pub(crate) type RequestId = u64;
+pub(crate) type TokenId = u32;
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[derive(Debug, Clone)]
     pub struct GenerationStep {
-        request_id: u64,
-        token_id: u32,
+        request_id: RequestId,
+        token_id: TokenId,
         log_prob: f32,
         is_final: bool,
         has_error: bool,
@@ -50,7 +53,7 @@ mod ffi {
         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            tokens: &[u32],
+            tokens: &[TokenId],
             max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
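Side note, for illustration only and not part of the diff: the two aliases are plain renames of u64/u32, so their value is readability at call sites. A minimal sketch of how they read, with a hypothetical submit helper standing in for the bridged Submit method:

// Hypothetical stand-in illustrating the aliases; the real Submit lives behind the cxx bridge.
pub(crate) type RequestId = u64;
pub(crate) type TokenId = u32;

fn submit(tokens: &[TokenId], max_new_tokens: u32) -> RequestId {
    // Pretend the executor accepted the request and handed back an id.
    let _ = (tokens, max_new_tokens);
    42
}

fn main() {
    let prompt: Vec<TokenId> = vec![1, 15043, 2];
    let request_id: RequestId = submit(&prompt, 128);
    println!("scheduled request {request_id} with {} prompt tokens", prompt.len());
}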

Changed file 2 of 2:

@@ -1,11 +1,10 @@
 use std::hint;
 use std::ops::Deref;
 use std::path::Path;
-use std::sync::OnceLock;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use hashbrown::HashMap;
+use hashbrown::{HashMap, HashSet};
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::error::SendError;
@@ -13,7 +12,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info, span, Level};
+use tracing::{debug, debug_span, error, info, info_span, span, Level};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -26,32 +25,74 @@ use text_generation_router::{FinishReason, Token};
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
+use crate::RequestId;
 
-// Value used to poll the state of the generation stream
-static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
-
-// It's safe to send the backend between threads
-unsafe impl Send for TensorRtLlmBackendImpl {}
-
 type InferResult<T> = Result<T, InferError>;
 
+struct IdentifiableRequest<T> {
+    request_id: RequestId,
+    inner: T,
+}
+
+macro_rules! identifiable {
+    ($id: expr, $inner: expr) => {
+        IdentifiableRequest {
+            id: $id,
+            inner: $inner,
+        }
+    };
+}
+
+/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
 struct ValidGenerateRequestWithTokens {
     encoding: Encoding,
     inner: ValidGenerateRequest,
 }
 
-struct DecodedTokenContext {
-    tokens: Vec<GenerationStep>,
-    ctx: UnboundedSender<InferResult<InferStreamResponse>>,
+/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
+struct GenerationContext {
+    request: ValidGenerateRequestWithTokens,
+    start: Instant,
+    queued: Option<Instant>,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
+#[derive(Debug, Copy, Clone)]
+struct DecodedToken {
+    id: u32,
+    log_prob: f32,
+    is_final: bool,
+}
+
+impl TryFrom<GenerationStep> for DecodedToken {
+    type Error = InferError;
+
+    fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
+        if !step.has_error {
+            Ok(Self {
+                id: step.token_id,
+                log_prob: step.log_prob,
+                is_final: step.is_final,
+            })
+        } else {
+            Err(GenerationError(step.error_msg))
+        }
+    }
+}
+
+/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
+struct DecodedTokenContext {
+    token: DecodedToken,
+    channel: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
-fn executor_status_poller(
+fn executor_status_looper(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
+    mut post_processor_sender: UnboundedSender<DecodedTokenContextWithRequestId>,
 ) {
     // Track the tuple (request_id, stream) for each request
-    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
+    let mut in_flights = HashMap::<RequestId, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
     'executor: loop {
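Illustrative sketch, not part of this commit: the TryFrom conversion above can be exercised on its own. The sketch below uses plain-Rust stand-ins for the cxx-bridge GenerationStep and for InferError, and takes the step by reference so both test cases can reuse it:

// Simplified stand-ins for the cxx-bridge GenerationStep and InferError.
#[derive(Debug, Clone)]
struct GenerationStep {
    request_id: u64,
    token_id: u32,
    log_prob: f32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

#[derive(Debug)]
struct GenerationError(String);

#[derive(Debug, Copy, Clone)]
struct DecodedToken {
    id: u32,
    log_prob: f32,
    is_final: bool,
}

impl TryFrom<&GenerationStep> for DecodedToken {
    type Error = GenerationError;

    fn try_from(step: &GenerationStep) -> Result<Self, Self::Error> {
        // Successful steps carry a token; failed steps carry only an error message.
        if !step.has_error {
            Ok(Self {
                id: step.token_id,
                log_prob: step.log_prob,
                is_final: step.is_final,
            })
        } else {
            Err(GenerationError(step.error_msg.clone()))
        }
    }
}

fn main() {
    let ok_step = GenerationStep {
        request_id: 1,
        token_id: 1337,
        log_prob: -0.12,
        is_final: false,
        has_error: false,
        error_msg: String::new(),
    };
    let failed_step = GenerationStep {
        has_error: true,
        error_msg: "executor aborted the request".into(),
        ..ok_step.clone()
    };

    assert!(DecodedToken::try_from(&ok_step).is_ok());
    assert!(DecodedToken::try_from(&failed_step).is_err());
    println!("conversion behaves as expected");
}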
@@ -60,7 +101,7 @@ fn executor_status_poller(
             let awaiting_requests = waiting_requests.len();
             for _ in 0..awaiting_requests {
                 // Retrieve all the requests
-                if let Some(ctx) = waiting_requests.blocking_recv() {
+                if let Some(mut ctx) = waiting_requests.blocking_recv() {
                     // Submit all the request to the executor and move the context to the in-flight tracker
                     let request = &ctx.request;
                     let generation_params = &request.inner.parameters;
@@ -79,13 +120,15 @@ fn executor_status_poller(
                 ) {
                     Ok(request_id) => {
                         // Insert the context linked to the generated request id in the tracker
+                        debug!("[in-flight] Added {}", request_id);
+                        ctx.queued = Instant::now();
                         in_flights.insert(request_id, ctx);
                     }
                     Err(e) => {
                         // Return to the caller
                         let what = Err(InferError::SchedulingError(e.to_string()));
-                        if let Err(e) = ctx.streamer.send(what) {
-                            error!("Failed to send back through the channel: {}", e);
+                        if let Err(ref e) = ctx.streamer.send(what) {
+                            error!("Failed to send the client", error = e.as_ref());
                         }
                     }
                 };
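For illustration only, not from the diff: the scheduling branch follows a submit-then-track pattern, parking each context in in_flights under the id returned by the executor. A std-only sketch of that bookkeeping, with a hypothetical fake_submit in place of the FFI call and Option<Instant> used for the queued timestamp:

use std::collections::HashMap;
use std::time::Instant;

// Hypothetical context, mirroring the fields the looper tracks per request.
struct GenerationContext {
    prompt: Vec<u32>,
    start: Instant,
    queued: Option<Instant>,
}

// Stand-in for the backend's submit call, which would return the executor's request id.
fn fake_submit(next_id: &mut u64, _tokens: &[u32]) -> Result<u64, String> {
    *next_id += 1;
    Ok(*next_id)
}

fn main() {
    let mut in_flights: HashMap<u64, GenerationContext> = HashMap::with_capacity(128);
    let mut next_id = 0;

    let mut ctx = GenerationContext {
        prompt: vec![1, 2, 3],
        start: Instant::now(),
        queued: None,
    };

    match fake_submit(&mut next_id, &ctx.prompt) {
        Ok(request_id) => {
            // Record when the request was handed to the executor, then park it.
            ctx.queued = Some(Instant::now());
            in_flights.insert(request_id, ctx);
        }
        Err(e) => eprintln!("scheduling failed: {e}"),
    }

    println!("{} request(s) in flight", in_flights.len());
}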
@@ -93,83 +136,38 @@ fn executor_status_poller(
-        if let Err(e) = span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
+        if let Err(ref e) = info_span!("[in-flight][poll]").in_scope(|| {
             if backend.num_responses_ready() > 0 {
-                match backend.pin_mut().pull_tokens() {
-                    Ok(responses) => {
-                        debug!("Received {} tokens from the executor", responses.len());
-
-                        // worse case scenario is one token for each response: with_capacity(responses.len())
-                        // grouper will group decoded tokens per request to decode multiple tokens
-                        let mut grouper: HashMap<u64, DecodedTokenContext> =
-                            HashMap::with_capacity(responses.len());
-
-                        // Iterate through all the decoded token
-                        for step in responses.deref() {
-                            match in_flights.get(&step.request_id) {
-                                Some(ctx) => {
-                                    debug!(
-                                        "{} -> (token={}, final={})",
-                                        step.request_id, step.token_id, step.is_final
-                                    );
-
-                                    // If no error, let's forward to post-processor
-                                    if !step.has_error {
-                                        let req_group = grouper.entry(step.request_id).or_insert(
-                                            DecodedTokenContext {
-                                                tokens: vec![],
-                                                ctx: ctx.streamer.clone(), // Arc::clone() = cheap
-                                            },
-                                        );
-                                        req_group.tokens.push(step.clone()); // Should be ultra cheap
-                                    } else {
-                                        warn!(
-                                            "Error for request: {} -> {}",
-                                            step.request_id, &step.error_msg
-                                        );
-                                        // TODO: Send something back to the postprocessor for the client?
-                                    }
-
-                                    // Remove from tracked requests
-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
-                                    }
-                                }
-                                None => {
-                                    if step.has_error {
-                                        error!(
-                                            "Untracked request {} -> {}",
-                                            step.request_id, &step.error_msg
-                                        );
-                                        continue;
-                                    } else {
-                                        error!(
-                                            "Got step for untracked request {}",
-                                            step.request_id
-                                        );
-                                    }
-                                }
-                            }
-                        }
-
-                        grouper
-                            .into_values()
-                            .map(|ctx| post_processor_sender.send(ctx))
-                            .collect::<Result<(), SendError<DecodedTokenContext>>>()?;
-                    }
-                    Err(err) => {
-                        error!("Failed to retrieve tokens from the executor: {}", err);
-                    }
-                }
-            }
-
-            Ok::<(), SendError<DecodedTokenContext>>(())
+                let responses = backend
+                    .pin_mut()
+                    .pull_tokens()
+                    .map_err(|e| Err(GenerationError(e.what())))?;
+
+                // Iterate through all the decoded token
+                for step in responses.deref() {
+                    if let Some(ctx) = in_flights.get(&step.request_id) {
+                        let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                            token: dt,
+                            channel: ctx.streamer.clone(),
+                        });
+
+                        // Submit the work to the post_processor
+                        let delivered = post_processor_sender.send(parcel);
+
+                        // Remove from tracked requests
+                        if step.is_final {
+                            debug!("Removing {}", step.request_id);
+                            let _ = in_flights.remove(&step.request_id);
+                        }
+
+                        delivered
+                    } else {
+                        warn!("Untracked request {}", step.request_id,);
+                    }
+                }?;
+            }
         }) {
-            error!(
-                "Caught an fatal error in the executor's loop, about to exit. {}",
-                e
-            );
+            error!("Error in the executor's loop, exiting", error = e.as_ref());
             break 'executor;
         }
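As a side note (illustrative only, not from this commit): the refactored poll path boils down to pull the ready steps, convert each one, forward it to the post-processor, and drop finished requests from the in-flight map. A condensed sketch of that flow using std::sync::mpsc in place of the tokio channels and plain structs in place of the FFI types:

use std::collections::HashMap;
use std::sync::mpsc::channel;

// Plain stand-ins for the FFI GenerationStep and the looper's context types.
struct GenerationStep {
    request_id: u64,
    token_id: u32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

#[derive(Debug)]
struct DecodedToken {
    id: u32,
    is_final: bool,
}

fn main() {
    let (post_processor_tx, post_processor_rx) = channel::<(u64, Result<DecodedToken, String>)>();

    // Two requests currently tracked; the value would normally hold the client stream.
    let mut in_flights: HashMap<u64, &'static str> =
        HashMap::from([(1, "client-a"), (2, "client-b")]);

    // Pretend these came back from the executor's pull_tokens().
    let responses = vec![
        GenerationStep { request_id: 1, token_id: 42, is_final: false, has_error: false, error_msg: String::new() },
        GenerationStep { request_id: 2, token_id: 7, is_final: true, has_error: false, error_msg: String::new() },
        GenerationStep { request_id: 9, token_id: 0, is_final: true, has_error: true, error_msg: "unknown".into() },
    ];

    for step in &responses {
        if in_flights.contains_key(&step.request_id) {
            // Convert the raw step, keeping errors as Err so the post-processor can report them.
            let parcel = if step.has_error {
                Err(step.error_msg.clone())
            } else {
                Ok(DecodedToken { id: step.token_id, is_final: step.is_final })
            };
            post_processor_tx.send((step.request_id, parcel)).expect("post-processor alive");

            // Final token: stop tracking the request.
            if step.is_final {
                in_flights.remove(&step.request_id);
            }
        } else {
            eprintln!("untracked request {}", step.request_id);
        }
    }

    drop(post_processor_tx);
    for (request_id, parcel) in post_processor_rx {
        println!("request {request_id}: {parcel:?}");
    }
    println!("{} request(s) still in flight", in_flights.len());
}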
@@ -180,7 +178,7 @@ fn executor_status_poller(
 
 fn post_processor_looper(
     tokenizer: Tokenizer,
-    mut decoded_tokens: UnboundedReceiver<DecodedTokenContext>,
+    mut decoded_tokens: UnboundedReceiver<DecodedTokenContextWithRequestId>,
 ) {
     'post_processor: loop {
         if decoded_tokens.is_closed() {
@@ -188,56 +186,14 @@ fn post_processor_looper(
             break 'post_processor;
         }
 
-        if let Some(ctx) = decoded_tokens.blocking_recv() {
-            ctx.tokens.iter().for_each(|step| {
-                let out = match tokenizer.decode(&[step.token_id], true) {
-                    Ok(text) => {
-                        let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
-                        let token = Token {
-                            id: step.token_id,
-                            text,
-                            logprob: step.log_prob,
-                            special: is_special,
-                        };
-
-                        let response = if !step.is_final {
-                            InferStreamResponse::Intermediate {
-                                token,
-                                top_tokens: vec![],
-                            }
-                        } else {
-                            InferStreamResponse::End {
-                                token,
-                                top_tokens: vec![],
-                                generated_text: GeneratedText {
-                                    text: String::from(""),
-                                    generated_tokens: 0,
-                                    finish_reason: FinishReason::Length,
-                                    seed: None,
-                                },
-                                start: Instant::now(),  // Handle start time
-                                queued: Instant::now(), // Handle queued time
-                            }
-                        };
-
-                        Ok(response)
-                    }
-                    Err(e) => Err(GenerationError(e.to_string())),
-                };
-
-                if let Err(e) = ctx.ctx.send(out) {
-                    warn!("Failed to send back the decoded tokens: {}", e);
-                };
-            });
-        }
-    }
-}
-
-struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
-}
+        let mut states = HashMap::with_capacity(128);
+
+        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
+            let state = states.entry(request_id).or_insert(vec![]);
+        }
+    }
+}
 
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
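Illustrative sketch, not part of the diff: the new post-processor keeps a per-request accumulator (states) so tokens belonging to the same request can be grouped before finalization. A small sketch of the entry(..).or_insert(..) pattern, assuming tokens arrive tagged with their request id:

use std::collections::HashMap;

fn main() {
    // (request_id, token_id) pairs as the post-processor would receive them.
    let decoded = [(1u64, 15043u32), (2, 29871), (1, 3186), (1, 2)];

    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);
    for (request_id, token_id) in decoded {
        // Group tokens by request so the full sequence is available at finalization time.
        let state = states.entry(request_id).or_insert(vec![]);
        state.push(token_id);
    }

    for (request_id, tokens) in &states {
        println!("request {request_id} accumulated {} token(s): {tokens:?}", tokens.len());
    }
}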
@@ -277,7 +233,7 @@ impl TensorRtLlmBackendV2 {
 
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         let executor_looper = spawn_blocking(move || {
-            executor_status_poller(backend, executor_receiver, post_processor_sender)
+            executor_status_looper(backend, executor_receiver, post_processor_sender)
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
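For illustration only (not from this commit): both loopers are blocking loops wired together with channels, which is why they are started with spawn_blocking. A dependency-free sketch of the same wiring, with std::thread standing in for tokio::task::spawn_blocking and std::sync::mpsc for the tokio channels:

use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

fn executor_loop(requests: Receiver<String>, post_processor: Sender<String>) {
    // Drain incoming requests and hand "decoded" work to the post-processor.
    for prompt in requests {
        post_processor
            .send(format!("tokens for '{prompt}'"))
            .expect("post-processor alive");
    }
}

fn post_processor_loop(decoded: Receiver<String>) {
    for item in decoded {
        println!("post-processed: {item}");
    }
}

fn main() {
    let (request_tx, request_rx) = channel();
    let (post_tx, post_rx) = channel();

    let executor = thread::spawn(move || executor_loop(request_rx, post_tx));
    let post_processor = thread::spawn(move || post_processor_loop(post_rx));

    request_tx.send("Hello".to_string()).unwrap();
    request_tx.send("World".to_string()).unwrap();
    drop(request_tx); // Closing the channel lets both loops terminate.

    executor.join().unwrap();
    post_processor.join().unwrap();
}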
@@ -295,22 +251,22 @@ impl TensorRtLlmBackendV2 {
 
     fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
         if request.top_n_tokens > 1 {
-            return Err(InferError::ValidationError(TopNTokensDisabled));
+            return Err(ValidationError(TopNTokensDisabled));
         }
 
         // TODO: Is it really needed? How can it be validated before?
         if request.parameters.grammar.is_some() {
-            return Err(InferError::ValidationError(Grammar));
+            return Err(ValidationError(Grammar));
         }
 
         match request.inputs.len() {
-            0 => Err(InferError::ValidationError(EmptyInput)),
-            2.. => Err(InferError::GenerationError(
+            0 => Err(ValidationError(EmptyInput)),
+            2.. => Err(GenerationError(
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
                 Chunk::Text(text) => Ok(text),
-                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+                Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
     }
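For reference only (not part of the diff): validate() is a pure function of the request shape and is easy to check in isolation. A simplified sketch with local stand-ins for ValidGenerateRequest, Chunk, and the error variants; the grammar check is omitted here:

// Local stand-ins for the router types used by validate().
enum Chunk {
    Text(String),
    Image(Vec<u8>),
}

#[derive(Debug)]
enum InferError {
    ValidationError(&'static str),
    GenerationError(String),
}

struct ValidGenerateRequest {
    inputs: Vec<Chunk>,
    top_n_tokens: u32,
}

fn validate(request: &ValidGenerateRequest) -> Result<&String, InferError> {
    if request.top_n_tokens > 1 {
        return Err(InferError::ValidationError("top_n_tokens is disabled"));
    }

    match request.inputs.len() {
        0 => Err(InferError::ValidationError("empty input")),
        2.. => Err(InferError::GenerationError(
            "TensorRT-LLM backend don't support multi-chunk".into(),
        )),
        1 => match request.inputs.first().expect("Single item-chunk") {
            Chunk::Text(text) => Ok(text),
            Chunk::Image(_) => Err(InferError::ValidationError("unsupported modality: image")),
        },
    }
}

fn main() {
    let ok = ValidGenerateRequest {
        inputs: vec![Chunk::Text("What is Deep Learning?".into())],
        top_n_tokens: 1,
    };
    let multi = ValidGenerateRequest {
        inputs: vec![Chunk::Text("a".into()), Chunk::Text("b".into())],
        top_n_tokens: 1,
    };

    assert!(validate(&ok).is_ok());
    assert!(validate(&multi).is_err());
    println!("validation sketch behaves as expected");
}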
@@ -336,7 +292,13 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
-        match self.executor.send(GenerationContext { request, streamer }) {
+        let start = Instant::now();
+        match self.executor.send(GenerationContext {
+            request,
+            start,
+            queued: None,
+            streamer,
+        }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
                 "Failed to submit request to the backend".into(),