feat(backend): make sure we can easily cancel requests on the executor

Morgan Funtowicz 2024-12-05 13:54:56 +01:00
parent 300f6c6f94
commit b3cd5ea076
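In short: token decoding now happens inline in the executor loop, and a failed send on a request's stream (meaning the client dropped its receiver) triggers backend.pin_mut().cancel(request_id) and removal from the in-flight map. Below is a minimal, self-contained sketch of that detection pattern; FakeExecutor is a hypothetical stand-in for the FFI TensorRtLlmBackendImpl, everything else is plain tokio.

use tokio::sync::mpsc::unbounded_channel;

// Hypothetical stand-in for the FFI executor; only `cancel` matters here.
struct FakeExecutor;

impl FakeExecutor {
    fn cancel(&mut self, request_id: u64) {
        println!("cancelled request {request_id}");
    }
}

fn main() {
    let mut executor = FakeExecutor;
    let (streamer, receiver) = unbounded_channel::<String>();

    // The client goes away (e.g. the HTTP connection closed), so the
    // receiving half of the per-request channel is dropped.
    drop(receiver);

    // The next send fails; in the loop below, that failure is the cue to
    // cancel the request on the executor and stop tracking it.
    let request_id = 42u64;
    if streamer.send("token".to_string()).is_err() {
        executor.cancel(request_id);
    }
}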


@@ -1,10 +1,9 @@
use std::hint;
use std::ops::Deref;
use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use std::hint;
use std::ops::Deref;
use std::path::Path;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
@@ -30,9 +29,10 @@ type InferResult<T> = Result<T, InferError>;
/// Wraps the request along with the channel used to stream the decoded tokens back to the client
struct GenerationContext {
request: ValidGenerateRequest,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
start: Option<Instant>,
queued: Instant,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
@@ -58,31 +58,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
}
}
/// Wraps the decoded token with the channel used to stream it back to the client
struct DecodedTokenContext {
token: DecodedToken,
start: Option<Instant>,
queued: Instant,
channel: UnboundedSender<InferResult<InferStreamResponse>>,
}
fn executor_status_looper(
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
mut waiting_requests: UnboundedReceiver<GenerationContext>,
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
tokenizer: Tokenizer,
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
mut backlog: UnboundedReceiver<GenerationContext>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
// TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = waiting_requests.len();
let awaiting_requests = backlog.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(mut ctx) = waiting_requests.blocking_recv() {
if let Some(ctx) = backlog.blocking_recv() {
// Submit all the requests to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
@@ -103,7 +94,6 @@ fn executor_status_looper(
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
@@ -117,6 +107,8 @@
}
}
};
} else {
break 'scheduler;
}
}
@@ -125,21 +117,28 @@
Ok(responses) => {
// Iterate through all the decoded tokens
for step in responses.deref() {
if let Some(ctx) = in_flights.get(&step.request_id) {
// Remove from tracked requests
let parcel =
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
token: dt,
start: ctx.start,
queued: ctx.queued,
channel: ctx.streamer.clone(),
});
if let Some(ctx) = in_flights.get_mut(&step.request_id) {
// Update the starting timestamp if not set
// This value might not be the actual starting time of the request
// on the executor side - we need to expose more info from the executor to
// retrieve this value
// TODO: Expose the actual starting time for a request on the FFI layer
if ctx.start.is_none() {
ctx.start = Some(Instant::now());
}
// Submit the work to the post_processor
let posted = post_processor_sender.send((step.request_id, parcel));
// Try to map the generation step to a DecodedToken
let response = match DecodedToken::try_from(step) {
Ok(decoded_token) => {
post_process_decoded_token(&tokenizer, ctx, decoded_token)
}
Err(err) => Err(err)
};
if posted.is_err() || step.is_final {
debug!("Removing {}", step.request_id);
// Attempt to send back the response to the client
if let Err(_) = ctx.streamer.send(response) {
// Client has dropped, remove from tracked requests
debug!("Client dropped - removing request {} from tracked requests", step.request_id);
backend.pin_mut().cancel(step.request_id);
let _ = in_flights.remove(&step.request_id);
}
@@ -160,80 +159,48 @@ fn executor_status_looper(
}
}
fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
tokenizer: Tokenizer,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
match tokenizer.decode(&[decoded_token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: decoded_token.id,
text,
logprob: decoded_token.log_prob,
special: is_special,
};
'post_processor: loop {
if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor;
}
// Append the token to the tracked generated tokens
ctx.tokens.push(token.id);
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
match decoded {
Ok(ctx) => {
states
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: ctx.token.id,
text,
logprob: ctx.token.log_prob,
special: is_special,
};
let out = if !ctx.token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
};
if let Err(_) = ctx.channel.send(out) {
warn!("Failed to send decoded token back to the user")
}
// Map to the correct response depending on whether the step is final or not
let out = if !decoded_token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
Err(_err) => {
todo!("what do we do?")
} else {
let text = tokenizer.decode(&ctx.tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
}
}
@@ -278,11 +245,8 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
unsafe impl Send for TensorRtLlmBackendImpl {}
pub struct TensorRtLlmBackendV2 {
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
}
pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
@@ -296,32 +260,22 @@ impl TensorRtLlmBackendV2 {
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling request state at regular intervals
let executor_looper = spawn_blocking(move || {
spawn_blocking(move || {
executor_status_looper(
backend,
max_inflight_requests,
tokenizer,
backend,
executor_receiver,
post_processor_sender,
)
});
// Post processor looper is responsible for receiving batches of tokens, decoding them and sending them back to the user
let post_processor_looper = spawn_blocking(move || {
post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
});
Ok(TensorRtLlmBackendV2 {
executor_looper,
post_processor_looper,
executor: executor_sender,
})
Ok(TensorRtLlmBackendV2(executor_sender))
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
@@ -355,20 +309,21 @@ impl TensorRtLlmBackendV2 {
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
inner: ValidGenerateRequest,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Self::validate(&inner)?;
Self::validate(&request)?;
// Open up the stream to send tokens
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.executor.send(GenerationContext {
request: inner,
match self.send(GenerationContext {
request,
streamer,
tokens: Vec::with_capacity(256),
start: None,
queued,
streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
@@ -378,6 +333,6 @@ impl Backend for TensorRtLlmBackendV2 {
}
async fn health(&self, _: bool) -> bool {
!self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
true
}
}
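For reference, cancellation needs nothing extra from the caller: dropping the UnboundedReceiverStream returned by schedule closes the channel, which the executor loop observes on its next send. A hedged usage sketch of that behavior, assuming only tokio and tokio-stream; the channel and values here are illustrative, not part of this commit.

use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-in for the per-request channel opened in `schedule`.
    let (streamer, receiver) = unbounded_channel::<u32>();
    let mut stream = UnboundedReceiverStream::new(receiver);

    streamer.send(1).unwrap();

    // The client reads one token, then abandons the generation by dropping
    // the stream it was handed back from `schedule`.
    let first = stream.next().await;
    println!("got {first:?}");
    drop(stream);

    // The executor loop's next send fails, which is its cue to call
    // backend.pin_mut().cancel(request_id) and forget the request.
    assert!(streamer.send(2).is_err());
}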