(scheduler) rework submit/pull logic

Morgan Funtowicz 2024-08-26 13:39:20 +00:00 committed by Morgan Funtowicz
parent 42ccf4e77c
commit fa63db0d07
2 changed files with 89 additions and 102 deletions

View File

@@ -4,17 +4,14 @@ pub mod errors;
 mod looper;
 mod utils;
 
-pub(crate) type RequestId = u64;
-pub(crate) type TokenId = u32;
-
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[derive(Debug, Clone)]
     pub struct GenerationStep {
-        request_id: RequestId,
-        token_id: TokenId,
+        request_id: u64,
+        token_id: u32,
         log_prob: f32,
         is_final: bool,
         has_error: bool,
@@ -53,7 +50,7 @@ mod ffi {
         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            tokens: &[TokenId],
+            tokens: &[u32],
             max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
@@ -68,4 +65,5 @@ mod ffi {
             self: Pin<&mut TensorRtLlmBackendImpl>,
         ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
     }
 }
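Taken together, the bridge above is the whole FFI surface the scheduler relies on: `submit` hands a tokenized prompt to the TensorRT-LLM executor and returns the executor-assigned request id, while `num_responses_ready` and `pull_tokens` drain completed decoding steps. A minimal caller sketch; the trailing parameter types (temperature, penalties, seed), which this hunk truncates, are assumed to match the call site in the second diff below:

```rust
// Sketch only: parameter types after top_p are assumptions (f32 penalties,
// u64 seed), inferred from the call site later in this commit.
use cxx::UniquePtr;

fn submit_one(backend: &mut UniquePtr<crate::ffi::TensorRtLlmBackendImpl>, tokens: &[u32]) {
    match backend.pin_mut().submit(
        tokens, // prompt token ids
        128,    // max_new_tokens
        50,     // top_k
        0.9,    // top_p
        1.0,    // temperature
        1.0,    // repetition_penalty
        0.0,    // frequency_penalty
        2024,   // seed
    ) {
        Ok(request_id) => println!("scheduled as {request_id}"),
        Err(e) => eprintln!("executor rejected the request: {e}"),
    }
}
```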

View File

@@ -3,46 +3,34 @@ use std::ops::Deref;
 use std::path::Path;
 
 use async_trait::async_trait;
-use cxx::UniquePtr;
-use hashbrown::{HashMap, HashSet};
+use cxx::{UniquePtr};
+use hashbrown::{HashMap};
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
-use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, debug_span, error, info, info_span, span, Level};
+use tracing::{debug, error};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
-use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
-use crate::RequestId;
 
 type InferResult<T> = Result<T, InferError>;
 
 struct IdentifiableRequest<T> {
-    request_id: RequestId,
+    request_id: u64,
     inner: T,
 }
 
-macro_rules! identifiable {
-    ($id: expr, $inner: expr) => {
-        IdentifiableRequest {
-            id: $id,
-            inner: $inner,
-        }
-    };
-}
-
 /// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
 struct ValidGenerateRequestWithTokens {
     encoding: Encoding,
@@ -52,8 +40,8 @@ struct ValidGenerateRequestWithTokens {
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
     request: ValidGenerateRequestWithTokens,
-    start: Instant,
-    queued: Option<Instant>,
+    start: Option<Instant>,
+    queued: Instant,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
@@ -64,10 +52,10 @@ struct DecodedToken {
     is_final: bool,
 }
 
-impl TryFrom<GenerationStep> for DecodedToken {
+impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     type Error = InferError;
 
-    fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
+    fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
         if !step.has_error {
             Ok(Self {
                 id: step.token_id,
@@ -75,7 +63,7 @@ impl TryFrom<GenerationStep> for DecodedToken {
                 is_final: step.is_final,
             })
         } else {
-            Err(GenerationError(step.error_msg))
+            Err(GenerationError(step.error_msg.clone()))
         }
     }
 }
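Borrowing the `GenerationStep` instead of taking it by value is what lets the pull loop convert steps straight out of the `CxxVector` returned by `pull_tokens`: a `CxxVector` only ever hands out references to its elements, so the by-value `TryFrom` could not be used without copying each struct first. The cost is one `clone()` of the error string, and only on the failure path. A sketch of the consuming side, assuming the same crate context as the diff:

```rust
// Sketch: converting borrowed steps without copying the vector's elements.
// `GenerationStep`, `DecodedToken`, and `InferResult` are the types above.
fn decode_all(
    steps: &cxx::CxxVector<crate::ffi::GenerationStep>,
) -> Vec<InferResult<DecodedToken>> {
    steps.iter().map(DecodedToken::try_from).collect()
}
```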
@@ -89,86 +77,84 @@ struct DecodedTokenContext {
 fn executor_status_looper(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    mut post_processor_sender: UnboundedSender<DecodedTokenContextWithRequestId>,
+    post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
 ) {
     // Track the tuple (request_id, stream) for each request
-    let mut in_flights = HashMap::<RequestId, GenerationContext>::with_capacity(128);
+    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
-    'executor: loop {
-        span!(Level::DEBUG, "[in-flight][submit]").in_scope(|| {
-            // Is there any request pending to be scheduled?
-            let awaiting_requests = waiting_requests.len();
-            for _ in 0..awaiting_requests {
-                // Retrieve all the requests
-                if let Some(mut ctx) = waiting_requests.blocking_recv() {
-                    // Submit all the request to the executor and move the context to the in-flight tracker
-                    let request = &ctx.request;
-                    let generation_params = &request.inner.parameters;
-                    let stopping_params = &request.inner.stopping_parameters;
-
-                    // Submit to the TensorRT-LLM executor for scheduling
-                    match backend.pin_mut().submit(
-                        request.encoding.get_ids(),
-                        stopping_params.max_new_tokens,
-                        generation_params.top_k as i32,
-                        generation_params.top_p,
-                        generation_params.temperature,
-                        generation_params.repetition_penalty,
-                        generation_params.frequency_penalty,
-                        generation_params.seed,
-                    ) {
-                        Ok(request_id) => {
-                            // Insert the context linked to the generated request id in the tracker
-                            debug!("[in-flight] Added {}", request_id);
-                            ctx.queued = Instant::now();
-                            in_flights.insert(request_id, ctx);
-                        }
-                        Err(e) => {
-                            // Return to the caller
-                            let what = Err(InferError::SchedulingError(e.to_string()));
-                            if let Err(ref e) = ctx.streamer.send(what) {
-                                error!("Failed to send the client", error = e.as_ref());
-                            }
-                        }
-                    };
-                }
-            }
-        });
-
-        if let Err(ref e) = info_span!("[in-flight][poll]").in_scope(|| {
-            if backend.num_responses_ready() > 0 {
-                let responses = backend
-                    .pin_mut()
-                    .pull_tokens()
-                    .map_err(|e| Err(GenerationError(e.what())))?;
-
-                // Iterate through all the decoded token
-                for step in responses.deref() {
-                    if let Some(ctx) = in_flights.get(&step.request_id) {
-                        let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                            token: dt,
-                            channel: ctx.streamer.clone(),
-                        });
-
-                        // Submit the work to the post_processor
-                        let delivered = post_processor_sender.send(parcel);
-
-                        // Remove from tracked requests
-                        if step.is_final {
-                            debug!("Removing {}", step.request_id);
-                            let _ = in_flights.remove(&step.request_id);
-                        }
-
-                        delivered
-                    } else {
-                        warn!("Untracked request {}", step.request_id,);
-                    }
-                }?;
-            }
-        }) {
-            error!("Error in the executor's loop, exiting", error = e.as_ref());
-            break 'executor;
-        }
+    'scheduler: loop {
+        // Is there any request pending to be scheduled?
+        let awaiting_requests = waiting_requests.len();
+        for _ in 0..awaiting_requests {
+            // Retrieve all the requests
+            if let Some(mut ctx) = waiting_requests.blocking_recv() {
+                // Submit all the request to the executor and move the context to the in-flight tracker
+                let request = &ctx.request;
+                let generation_params = &request.inner.parameters;
+                let stopping_params = &request.inner.stopping_parameters;
+
+                // Submit to the TensorRT-LLM executor for scheduling
+                match backend.pin_mut().submit(
+                    request.encoding.get_ids(),
+                    stopping_params.max_new_tokens,
+                    generation_params.top_k as i32,
+                    generation_params.top_p,
+                    generation_params.temperature,
+                    generation_params.repetition_penalty,
+                    generation_params.frequency_penalty,
+                    generation_params.seed,
+                ) {
+                    Ok(request_id) => {
+                        // Insert the context linked to the generated request id in the tracker
+                        debug!("[in-flight] Added {}", request_id);
+                        ctx.start = Some(Instant::now());
+                        in_flights.insert(request_id, ctx);
+                    }
+                    Err(e) => {
+                        // Return to the caller
+                        let what = e.to_string();
+                        error!(error = what.as_str(), "Failed to schedule request");
+
+                        let err = Err(InferError::SchedulingError(what));
+                        if let Err(_) = ctx.streamer.send(err) {
+                            error!("Failed to send back error to the client");
+                        }
+                    }
+                };
+            }
+        }
+
+        if backend.num_responses_ready() > 0 {
+            match backend.pin_mut().pull_tokens() {
+                Ok(responses) => {
+                    // Iterate through all the decoded token
+                    for step in responses.deref() {
+                        if let Some(ctx) = in_flights.get(&step.request_id) {
+                            // Remove from tracked requests
+                            let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                                token: dt,
+                                channel: ctx.streamer.clone(),
+                            });
+
+                            // Submit the work to the post_processor
+                            let posted = post_processor_sender.send((step.request_id, parcel));
+
+                            if posted.is_err() || step.is_final {
+                                debug!("Removing {}", step.request_id);
+                                let _ = in_flights.remove(&step.request_id);
+                            }
+                        } else {
+                            warn!("Untracked request {}", step.request_id,);
+                        }
+                    }
+                }
+                Err(ref err) => {
+                    error!("Failed to get responses from the executor: {}.", err.what());
+                    break 'scheduler;
+                }
+            }
+        }
 
         // Hint the CPU we are spin-locking
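The rework flattens the loop into two straight-line phases per iteration, submit then pull, and drops the tracing-span closures that were also being used to smuggle `Result` values out of the loop. Error paths are now direct: a failed `submit` is reported on that request's own stream, a failed `pull_tokens` tears the scheduler down via `break 'scheduler`, and a context is evicted either when its final token arrives or when the post-processor channel is gone (`posted.is_err()`), so abandoned requests no longer linger in `in_flights`. The control flow in miniature, with all domain types stubbed out:

```rust
// Self-contained miniature of the two-phase scheduler loop; only the
// control flow mirrors the commit, every type here is a stand-in.
use std::collections::HashMap;

fn scheduler_loop(
    mut pending: Vec<u32>,                                       // waiting_requests
    mut poll: impl FnMut() -> Result<Vec<(u64, bool)>, String>,  // pull_tokens
) {
    let mut in_flights: HashMap<u64, u32> = HashMap::with_capacity(128);
    let mut next_id = 0u64;
    'scheduler: loop {
        // Phase 1: drain everything currently waiting and submit it.
        for prompt in pending.drain(..) {
            in_flights.insert(next_id, prompt);
            next_id += 1;
        }
        // Phase 2: pull ready steps; only final steps evict their entry.
        match poll() {
            Ok(steps) => {
                for (request_id, is_final) in steps {
                    if is_final {
                        in_flights.remove(&request_id);
                    }
                }
            }
            Err(_) => break 'scheduler, // executor failure ends the loop
        }
        if in_flights.is_empty() {
            break;
        }
    }
}
```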
@@ -178,7 +164,7 @@ fn executor_status_looper(
 fn post_processor_looper(
     tokenizer: Tokenizer,
-    mut decoded_tokens: UnboundedReceiver<DecodedTokenContextWithRequestId>,
+    mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
 ) {
     'post_processor: loop {
         if decoded_tokens.is_closed() {
@@ -186,7 +172,7 @@ fn post_processor_looper(
             break 'post_processor;
         }
 
-        let mut states = HashMap::with_capacity(128);
+        let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);
 
         if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
             let state = states.entry(request_id).or_insert(vec![]);
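On the receiving end, `post_processor_looper` now gets plain `(request_id, result)` tuples instead of the deleted `IdentifiableRequest` wrapper and its `identifiable!` macro, and keys the per-request token buffers off the id. A runnable miniature of that channel protocol (tokio with the `rt` and `macros` features assumed; payloads simplified to bare token ids):

```rust
use std::collections::HashMap;
use tokio::sync::mpsc::unbounded_channel;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // (request_id, decoded token id or error), mirroring the reworked channel type.
    let (tx, mut rx) = unbounded_channel::<(u64, Result<u32, String>)>();
    tx.send((42, Ok(7))).unwrap();
    tx.send((42, Ok(8))).unwrap();
    drop(tx); // closing the sender ends the loop below

    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);
    while let Some((request_id, decoded)) = rx.recv().await {
        match decoded {
            Ok(token) => states.entry(request_id).or_insert_with(Vec::new).push(token),
            Err(e) => eprintln!("request {request_id} failed: {e}"),
        }
    }
    assert_eq!(states[&42], vec![7, 8]);
}
```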
@@ -194,6 +180,9 @@ fn post_processor_looper(
     }
 }
 
+unsafe impl Send for crate::ffi::TensorRtLlmBackendImpl {}
+
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
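The new `unsafe impl Send` marker is what allows the `UniquePtr<TensorRtLlmBackendImpl>` to move into the blocking looper task: cxx opaque C++ types are not `Send` by default, and `spawn_blocking` demands it. The commit asserts, without compiler help, that the C++ executor tolerates being driven from another thread. The bound in isolation:

```rust
// Sketch: anything moved into spawn_blocking must be Send + 'static, which
// is exactly the bound the unsafe marker satisfies for the cxx opaque type.
use tokio::task::{spawn_blocking, JoinHandle};

fn launch_looper<B: Send + 'static>(backend: B) -> JoinHandle<()> {
    spawn_blocking(move || {
        let _backend = backend; // owned by the blocking thread from here on
    })
}
```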
@@ -292,11 +281,11 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
-        let start = Instant::now();
+        let queued = Instant::now();
         match self.executor.send(GenerationContext {
             request,
-            start,
-            queued: None,
+            start: None,
+            queued,
             streamer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
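With `queued` stamped here and `start` filled in later by the executor loop, timing is captured end to end before the receiver stream is handed back to the router. A hedged sketch of the caller's side; the concrete item type comes from the `Backend` trait, which this diff does not show:

```rust
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

// Illustrative consumer (not in the commit): each item is one decoded step,
// or an error, produced by the post-processor for this request.
async fn forward<T>(mut stream: UnboundedReceiverStream<T>) {
    while let Some(_item) = stream.next().await {
        // forward to the HTTP/SSE layer
    }
}
```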