feat(backend): make sure we can easily cancel request on the executor

Morgan Funtowicz 2024-12-05 13:54:56 +01:00
parent 300f6c6f94
commit b3cd5ea076


@@ -1,10 +1,9 @@
-use std::hint;
-use std::ops::Deref;
-use std::path::Path;
-
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
@@ -30,9 +29,10 @@ type InferResult<T> = Result<T, InferError>;
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
     request: ValidGenerateRequest,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }

 #[derive(Debug, Copy, Clone)]
@@ -58,31 +58,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     }
 }

-/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
-struct DecodedTokenContext {
-    token: DecodedToken,
-    start: Option<Instant>,
-    queued: Instant,
-    channel: UnboundedSender<InferResult<InferStreamResponse>>,
-}
-
 fn executor_status_looper(
-    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     max_inflight_requests: usize,
-    mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
+    tokenizer: Tokenizer,
+    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+    mut backlog: UnboundedReceiver<GenerationContext>,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights =
         HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);

-    // TODO: Does it need a spin-loop?
     'scheduler: loop {
         // Is there any request pending to be scheduled?
-        let awaiting_requests = waiting_requests.len();
+        let awaiting_requests = backlog.len();
         for _ in 0..awaiting_requests {
             // Retrieve all the requests
-            if let Some(mut ctx) = waiting_requests.blocking_recv() {
+            if let Some(ctx) = backlog.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
                 let generation_params = &request.parameters;
@@ -103,7 +94,6 @@ fn executor_status_looper(
                     Ok(request_id) => {
                         // Insert the context linked to the generated request id in the tracker
                         debug!("[in-flight] Added {}", request_id);
-                        ctx.start = Some(Instant::now());
                         in_flights.insert(request_id, ctx);
                     }
                     Err(e) => {
@@ -117,6 +107,8 @@ fn executor_status_looper(
                         }
                     }
                 };
+            } else {
+                break 'scheduler;
             }
         }
@@ -125,21 +117,28 @@ fn executor_status_looper(
                 Ok(responses) => {
                     // Iterate through all the decoded token
                     for step in responses.deref() {
-                        if let Some(ctx) = in_flights.get(&step.request_id) {
-                            // Remove from tracked requests
-                            let parcel =
-                                DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                                    token: dt,
-                                    start: ctx.start,
-                                    queued: ctx.queued,
-                                    channel: ctx.streamer.clone(),
-                                });
-                            // Submit the work to the post_processor
-                            let posted = post_processor_sender.send((step.request_id, parcel));
-                            if posted.is_err() || step.is_final {
-                                debug!("Removing {}", step.request_id);
+                        if let Some(ctx) = in_flights.get_mut(&step.request_id) {
+                            // Update the starting timestamp if not set
+                            // This value might not be the actual real starting time of the request
+                            // on the executor side - Need to expose more info from the executor to
+                            // retrieve this value
+                            // TODO : Expose actual real starting time for a request on FFI layer
+                            if ctx.start.is_none() {
+                                ctx.start = Some(Instant::now());
+                            }
+                            // Try to map the generation step to a DecodedToken
+                            let response = match DecodedToken::try_from(step) {
+                                Ok(decoded_token) => {
+                                    post_process_decoded_token(&tokenizer, ctx, decoded_token)
+                                }
+                                Err(err) => Err(err)
+                            };
+                            // Attempt to send back the response to the client
+                            if let Err(_) = ctx.streamer.send(response) {
+                                // Client has dropped, remove from tracked requests
+                                debug!("Client dropped - removing request {} from tracked requests", step.request_id);
                                 backend.pin_mut().cancel(step.request_id);
                                 let _ = in_flights.remove(&step.request_id);
                             }
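The hunk above carries the core of the change: decoded tokens are pushed straight onto the per-request streamer, and a failed send doubles as the cancellation signal. Below is a minimal, self-contained sketch of that pattern; FakeExecutor, dispatch_step and the String payload are illustrative stand-ins, not the backend's actual types.

use std::collections::HashMap;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};

// Hypothetical stand-in for the FFI executor.
struct FakeExecutor;
impl FakeExecutor {
    fn cancel(&mut self, request_id: u64) {
        println!("cancelling request {request_id} on the executor");
    }
}

fn dispatch_step(
    executor: &mut FakeExecutor,
    in_flights: &mut HashMap<u64, UnboundedSender<String>>,
    request_id: u64,
    token: String,
) {
    if let Some(streamer) = in_flights.get(&request_id) {
        // A failed send means the client-side receiver was dropped:
        // cancel the request on the executor and stop tracking it.
        if streamer.send(token).is_err() {
            executor.cancel(request_id);
            in_flights.remove(&request_id);
        }
    }
}

fn main() {
    let mut executor = FakeExecutor;
    let mut in_flights = HashMap::new();

    let (tx, rx) = unbounded_channel::<String>();
    in_flights.insert(42, tx);

    drop(rx); // simulate the client dropping the response stream
    dispatch_step(&mut executor, &mut in_flights, 42, "hello".into());
    assert!(!in_flights.contains_key(&42));
}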
@@ -160,54 +159,33 @@ fn executor_status_looper(
     }
 }

-fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
-    tokenizer: Tokenizer,
-    max_inflight_requests: usize,
-    mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
-) {
-    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
-
-    'post_processor: loop {
-        if decoded_tokens.is_closed() {
-            warn!("Post processor IPC is closed, loop will exit now.");
-            break 'post_processor;
-        }
-
-        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
-            match decoded {
-                Ok(ctx) => {
-                    states
-                        .entry(request_id)
-                        .and_modify(|s| s.push(*&ctx.token.id))
-                        .or_insert_with(|| {
-                            let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
-                            state.push(*&ctx.token.id);
-                            state
-                        });
-
-                    let out = match tokenizer.decode(&[ctx.token.id], false) {
+fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
+    match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
             let is_special =
                 tokenizer.get_added_vocabulary().is_special_token(&text);
             let token = Token {
-                id: ctx.token.id,
+                id: decoded_token.id,
                 text,
-                logprob: ctx.token.log_prob,
+                logprob: decoded_token.log_prob,
                 special: is_special,
             };

-            let out = if !ctx.token.is_final {
+            // Append the token to the tracked generated tokens
+            ctx.tokens.push(token.id);
+            // Map the correct response depending on the step is final or not
+            let out = if !decoded_token.is_final {
                 InferStreamResponse::Intermediate {
                     token,
                     top_tokens: vec![],
                 }
             } else {
-                let tokens = states.remove(&request_id).unwrap();
-                let text = tokenizer.decode(&tokens, true);
+                let text = tokenizer.decode(&ctx.tokens, true);
                 let generated_text = GeneratedText {
                     text: text.unwrap(),
-                    generated_tokens: tokens.len() as u32,
-                    finish_reason: FinishReason::EndOfSequenceToken,
+                    generated_tokens: ctx.tokens.len() as u32,
+                    finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
                     seed: None,
                 };
@@ -223,17 +201,6 @@ fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
             Ok(out)
         }
         Err(err) => Err(GenerationError(err.to_string())),
-                    };
-
-                    if let Err(_) = ctx.channel.send(out) {
-                        warn!("Failed to send decoded token back to the user")
-                    }
-                }
-                Err(_err) => {
-                    todo!("what do we do?")
-                }
-            }
-        }
     }
 }
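With the dedicated post-processor loop removed, each request's token history now lives in GenerationContext::tokens and the final text is produced by decoding that accumulated vector in one pass. The sketch below mirrors the intermediate-versus-final branching with simplified, hypothetical types; the real code goes through the tokenizer and the Token/GeneratedText structs shown in the diff.

// Hypothetical, simplified stand-ins for the backend's response types.
#[derive(Debug)]
enum StreamResponse {
    Intermediate { token_id: u32 },
    End { token_id: u32, text: String, generated_tokens: u32 },
}

struct Context {
    tokens: Vec<u32>, // every generated token id for this request
}

fn post_process(ctx: &mut Context, token_id: u32, is_final: bool) -> StreamResponse {
    // Track the token id so the final text can be decoded in one pass.
    ctx.tokens.push(token_id);
    if !is_final {
        StreamResponse::Intermediate { token_id }
    } else {
        // A real implementation would run the tokenizer here; we fake it.
        let text = ctx.tokens.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(" ");
        StreamResponse::End {
            token_id,
            text,
            generated_tokens: ctx.tokens.len() as u32,
        }
    }
}

fn main() {
    let mut ctx = Context { tokens: Vec::with_capacity(256) };
    println!("{:?}", post_process(&mut ctx, 10, false));
    println!("{:?}", post_process(&mut ctx, 11, true));
}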
@@ -278,11 +245,8 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
 unsafe impl Send for TensorRtLlmBackendImpl {}

-pub struct TensorRtLlmBackendV2 {
-    executor_looper: JoinHandle<()>,
-    post_processor_looper: JoinHandle<()>,
-    executor: UnboundedSender<GenerationContext>,
-}
+pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);

 impl TensorRtLlmBackendV2 {
     pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
@@ -296,32 +260,22 @@ impl TensorRtLlmBackendV2 {
         // Allocate the IPC layer to communicate with the backend
         let (executor_sender, executor_receiver) = unbounded_channel();
-        let (post_processor_sender, post_processor_receiver) = unbounded_channel();

         // Create the FFI backend
         let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

         // Executor looper is responsible for scheduling and pulling requests state at regular interval
-        let executor_looper = spawn_blocking(move || {
+        spawn_blocking(move || {
             executor_status_looper(
-                backend,
                 max_inflight_requests,
+                tokenizer,
+                backend,
                 executor_receiver,
-                post_processor_sender,
             )
         });

-        // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
-        let post_processor_looper = spawn_blocking(move || {
-            post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
-        });
-
-        Ok(TensorRtLlmBackendV2 {
-            executor_looper,
-            post_processor_looper,
-            executor: executor_sender,
-        })
+        Ok(TensorRtLlmBackendV2(executor_sender))
     }

     fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
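TensorRtLlmBackendV2 is now just a tuple struct around the UnboundedSender feeding the executor loop; the JoinHandles are no longer kept, because the blocking loop exits on its own once every sender is dropped and blocking_recv() returns None (the break 'scheduler added earlier). A rough sketch of that lifecycle, with a Backend wrapper and String requests standing in for the real types:

use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

// Hypothetical wrapper mirroring the tuple-struct shape of TensorRtLlmBackendV2.
struct Backend(UnboundedSender<String>);

fn scheduler_loop(mut backlog: UnboundedReceiver<String>) {
    // blocking_recv() returns None once every sender is dropped,
    // which is what lets the loop exit cleanly without a JoinHandle.
    while let Some(request) = backlog.blocking_recv() {
        println!("scheduling {request}");
    }
    println!("backlog closed, scheduler exiting");
}

fn main() {
    let (tx, rx) = unbounded_channel();
    let worker = std::thread::spawn(move || scheduler_loop(rx));

    let backend = Backend(tx);
    backend.0.send("request-1".into()).unwrap();
    drop(backend); // dropping the only sender closes the backlog

    worker.join().unwrap();
}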
@@ -355,20 +309,21 @@ impl TensorRtLlmBackendV2 {
 impl Backend for TensorRtLlmBackendV2 {
     fn schedule(
         &self,
-        inner: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        Self::validate(&inner)?;
+        Self::validate(&request)?;

         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();

         // Send the context to the executor for scheduling
         let queued = Instant::now();
-        match self.executor.send(GenerationContext {
-            request: inner,
+        match self.0.send(GenerationContext {
+            request,
+            streamer,
+            tokens: Vec::with_capacity(256),
             start: None,
             queued,
-            streamer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
@@ -378,6 +333,6 @@ impl Backend for TensorRtLlmBackendV2 {
     }

     async fn health(&self, _: bool) -> bool {
-        !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
+        true
     }
 }
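From the router's point of view no explicit cancel call is needed: dropping the stream returned by schedule() closes the channel, and the executor loop cancels the request on the next token it fails to deliver. A hedged, stand-alone illustration of that consumer-side behaviour, with u32 payloads instead of InferStreamResponse:

use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // Stand-in for the channel a schedule() call would open.
    let (streamer, receiver) = unbounded_channel::<u32>();
    let mut stream = UnboundedReceiverStream::new(receiver);

    streamer.send(1).unwrap();
    assert_eq!(stream.next().await, Some(1));

    // Dropping the stream is the cancellation signal: the next send fails,
    // which is where the executor loop calls backend.cancel(request_id).
    drop(stream);
    assert!(streamer.send(2).is_err());
}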