mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
(scheduler) rework submit/pull logic
This commit is contained in:
parent 42ccf4e77c
commit fa63db0d07
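In outline, the reworked logic runs a single scheduler loop that first drains the queue of waiting requests and submits each one to the TensorRT-LLM executor, then pulls any ready GenerationStep results and forwards them, keyed by request id, to a separate post-processing loop. The sketch below only illustrates that shape and is not the backend's code: FakeExecutor stands in for the cxx-bridged TensorRtLlmBackendImpl, and payloads are simplified to plain strings.

use std::collections::HashMap;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};

/// One queued request: the prompt plus the channel used to stream results back.
struct GenerationContext {
    prompt: String,
    streamer: UnboundedSender<String>,
}

/// Stand-in for the FFI executor: submit() assigns a request id, pull_tokens()
/// returns whatever finished since the last poll as (request_id, token, is_final).
#[derive(Default)]
struct FakeExecutor {
    next_id: u64,
    ready: Vec<(u64, String, bool)>,
}

impl FakeExecutor {
    fn submit(&mut self, prompt: &str) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        // Pretend the whole generation completes in a single step.
        self.ready.push((id, format!("echo: {prompt}"), true));
        id
    }

    fn pull_tokens(&mut self) -> Vec<(u64, String, bool)> {
        std::mem::take(&mut self.ready)
    }
}

/// Meant to run on a blocking thread (e.g. via tokio::task::spawn_blocking).
fn executor_status_looper(
    mut backend: FakeExecutor,
    mut waiting_requests: UnboundedReceiver<GenerationContext>,
    post_processor_sender: UnboundedSender<(u64, String, UnboundedSender<String>)>,
) {
    // Track in-flight requests by the id the executor assigned at submit time.
    let mut in_flights: HashMap<u64, GenerationContext> = HashMap::with_capacity(128);

    'scheduler: loop {
        // 1. Submit everything currently waiting.
        for _ in 0..waiting_requests.len() {
            if let Some(ctx) = waiting_requests.blocking_recv() {
                let request_id = backend.submit(&ctx.prompt);
                in_flights.insert(request_id, ctx);
            }
        }

        // 2. Pull finished decoding steps and forward them, keyed by request id.
        for (request_id, token, is_final) in backend.pull_tokens() {
            if let Some(ctx) = in_flights.get(&request_id) {
                let posted = post_processor_sender.send((request_id, token, ctx.streamer.clone()));
                // Drop the tracking entry once the request is done or undeliverable.
                if posted.is_err() || is_final {
                    in_flights.remove(&request_id);
                }
            }
        }

        if waiting_requests.is_closed() && in_flights.is_empty() {
            break 'scheduler;
        }
        // Hint the CPU that this is a busy-wait, as the real loop does.
        std::hint::spin_loop();
    }
}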
@@ -4,17 +4,14 @@ pub mod errors;
mod looper;
mod utils;

pub(crate) type RequestId = u64;
pub(crate) type TokenId = u32;

#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
    /// Struct used as shared type between rust and C++ to represent the result
    /// of a single decoding iteration
    #[derive(Debug, Clone)]
    pub struct GenerationStep {
        request_id: RequestId,
        token_id: TokenId,
        request_id: u64,
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        has_error: bool,
@@ -53,7 +50,7 @@ mod ffi {
        #[rust_name = "submit"]
        fn Submit(
            self: Pin<&mut TensorRtLlmBackendImpl>,
            tokens: &[TokenId],
            tokens: &[u32],
            max_new_tokens: u32,
            top_k: i32,
            top_p: f32,
@@ -68,4 +65,5 @@ mod ffi {
            self: Pin<&mut TensorRtLlmBackendImpl>,
        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
    }

}
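For readers who have not used cxx before, here is a minimal, standalone illustration of the shared-struct pattern GenerationStep relies on; it is not the backend's actual bridge, and every name in it is invented for the example. A struct declared inside #[cxx::bridge] gets a single definition that both the Rust and the C++ side see with the same layout.

#[cxx::bridge(namespace = "demo")]
mod ffi {
    /// Plain-old-data shared by value between Rust and C++.
    #[derive(Debug, Clone)]
    pub struct Step {
        request_id: u64,
        token_id: u32,
        log_prob: f32,
        is_final: bool,
    }
}

fn main() {
    // On the Rust side the shared struct behaves like any ordinary struct.
    let step = ffi::Step {
        request_id: 1,
        token_id: 42,
        log_prob: -0.25,
        is_final: true,
    };
    println!("{step:?}");
}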
@@ -3,46 +3,34 @@ use std::ops::Deref;
use std::path::Path;

use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::{HashMap, HashSet};
use cxx::{UniquePtr};
use hashbrown::{HashMap};
use log::warn;
use tokenizers::{Encoding, Tokenizer};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, debug_span, error, info, info_span, span, Level};
use tracing::{debug, error};

use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
    EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};

use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
use crate::utils::first_line;
use crate::RequestId;

type InferResult<T> = Result<T, InferError>;

struct IdentifiableRequest<T> {
    request_id: RequestId,
    request_id: u64,
    inner: T,
}

macro_rules! identifiable {
    ($id: expr, $inner: expr) => {
        IdentifiableRequest {
            id: $id,
            inner: $inner,
        }
    };
}

/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
struct ValidGenerateRequestWithTokens {
    encoding: Encoding,
@@ -52,8 +40,8 @@ struct ValidGenerateRequestWithTokens
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
    request: ValidGenerateRequestWithTokens,
    start: Instant,
    queued: Option<Instant>,
    start: Option<Instant>,
    queued: Instant,
    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
@@ -64,10 +52,10 @@ struct DecodedToken {
    is_final: bool,
}

impl TryFrom<GenerationStep> for DecodedToken {
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
    type Error = InferError;

    fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
    fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
        if !step.has_error {
            Ok(Self {
                id: step.token_id,
@@ -75,7 +63,7 @@ impl TryFrom<GenerationStep> for DecodedToken {
                is_final: step.is_final,
            })
        } else {
            Err(GenerationError(step.error_msg))
            Err(GenerationError(step.error_msg.clone()))
        }
    }
}
@@ -89,14 +77,13 @@ struct DecodedTokenContext {
fn executor_status_looper(
    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
    mut waiting_requests: UnboundedReceiver<GenerationContext>,
    mut post_processor_sender: UnboundedSender<DecodedTokenContextWithRequestId>,
    post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
) {
    // Track the tuple (request_id, stream) for each request
    let mut in_flights = HashMap::<RequestId, GenerationContext>::with_capacity(128);
    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);

    // TODO: Does it need a spin-loop?
    'executor: loop {
        span!(Level::DEBUG, "[in-flight][submit]").in_scope(|| {
    'scheduler: loop {
        // Is there any request pending to be scheduled?
        let awaiting_requests = waiting_requests.len();
        for _ in 0..awaiting_requests {
@@ -121,54 +108,53 @@ fn executor_status_looper(
                    Ok(request_id) => {
                        // Insert the context linked to the generated request id in the tracker
                        debug!("[in-flight] Added {}", request_id);
                        ctx.queued = Instant::now();
                        ctx.start = Some(Instant::now());
                        in_flights.insert(request_id, ctx);
                    }
                    Err(e) => {
                        // Return to the caller
                        let what = Err(InferError::SchedulingError(e.to_string()));
                        if let Err(ref e) = ctx.streamer.send(what) {
                            error!("Failed to send the client", error = e.as_ref());
                        let what = e.to_string();
                        error!(error = what.as_str(), "Failed to schedule request");

                        let err = Err(InferError::SchedulingError(what));
                        if let Err(_) = ctx.streamer.send(err) {
                            error!("Failed to send back error to the client");
                        }
                    }
                };
            }
        }
        });

        if let Err(ref e) = info_span!("[in-flight][poll]").in_scope(|| {
        if backend.num_responses_ready() > 0 {
            let responses = backend
                .pin_mut()
                .pull_tokens()
                .map_err(|e| Err(GenerationError(e.what())))?;

            match backend.pin_mut().pull_tokens() {
                Ok(responses) => {
                    // Iterate through all the decoded token
                    for step in responses.deref() {
                        if let Some(ctx) = in_flights.get(&step.request_id) {

                            // Remove from tracked requests
                            let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
                                token: dt,
                                channel: ctx.streamer.clone(),
                            });

                            // Submit the work to the post_processor
                            let delivered = post_processor_sender.send(parcel);
                            // Submit the work to the post_processor
                            let posted = post_processor_sender.send((step.request_id, parcel));

                            // Remove from tracked requests
                            if step.is_final {
                            if posted.is_err() || step.is_final {
                                debug!("Removing {}", step.request_id);
                                let _ = in_flights.remove(&step.request_id);
                            }

                            delivered
                        } else {
                            warn!("Untracked request {}", step.request_id,);
                        }
                    }?;
                };
                }
                Err(ref err) => {
                    error!("Failed to get responses from the executor: {}.", err.what());
                    break 'scheduler;
                }
            }
        }) {
            error!("Error in the executor's loop, exiting", error = e.as_ref());
            break 'executor;
        }

        // Hint the CPU we are spin-locking
@@ -178,7 +164,7 @@ fn executor_status_looper(
fn post_processor_looper(
    tokenizer: Tokenizer,
    mut decoded_tokens: UnboundedReceiver<DecodedTokenContextWithRequestId>,
    mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
    'post_processor: loop {
        if decoded_tokens.is_closed() {
@@ -186,7 +172,7 @@ fn post_processor_looper(
            break 'post_processor;
        }

        let mut states = HashMap::with_capacity(128);
        let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);

        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
            let state = states.entry(request_id).or_insert(vec![]);
@@ -194,6 +180,9 @@ fn post_processor_looper(
    }
}


unsafe impl Send for crate::ffi::TensorRtLlmBackendImpl {}

pub struct TensorRtLlmBackendV2 {
    tokenizer: Tokenizer,
    executor_looper: JoinHandle<()>,
@@ -292,11 +281,11 @@ impl Backend for TensorRtLlmBackendV2 {
        let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();

        // Send the context to the executor for scheduling
        let start = Instant::now();
        let queued = Instant::now();
        match self.executor.send(GenerationContext {
            request,
            start,
            queued: None,
            start: None,
            queued,
            streamer,
        }) {
            Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
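For orientation, a small, self-contained example (not from this commit) of the channel/stream pattern the schedule path relies on: the backend side keeps the UnboundedSender and pushes one item per decoded token, while the caller wraps the receiver in an UnboundedReceiverStream, as the Ok arm above does, and awaits items. The payload is simplified here to Result<String, String> instead of InferResult<InferStreamResponse>.

use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // Backend side: keep the sender, push one item per decoded chunk.
    let (tx, rx) = unbounded_channel::<Result<String, String>>();
    tx.send(Ok("Hello".to_string())).unwrap();
    tx.send(Ok(" world".to_string())).unwrap();
    drop(tx); // closing the sender ends the stream

    // Caller side: wrap the receiver and drain it as an async stream.
    let mut stream = UnboundedReceiverStream::new(rx);
    while let Some(item) = stream.next().await {
        match item {
            Ok(chunk) => print!("{chunk}"),
            Err(e) => eprintln!("generation failed: {e}"),
        }
    }
    println!();
}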