Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-04 06:50:15 +00:00)

Commit b1846fb4e6: "(backend) refactor & cleanup"
Parent: 483f172938
@@ -4,14 +4,17 @@ pub mod errors;
 mod looper;
 mod utils;
 
+pub(crate) type RequestId = u64;
+pub(crate) type TokenId = u32;
+
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
     #[derive(Debug, Clone)]
     pub struct GenerationStep {
-        request_id: u64,
-        token_id: u32,
+        request_id: RequestId,
+        token_id: TokenId,
         log_prob: f32,
         is_final: bool,
         has_error: bool,
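For context, the new `RequestId` and `TokenId` aliases give the crate a single definition of the integer widths shared across the FFI boundary. A standalone sketch of the idea (values illustrative, not from the commit):

    // One alias definition keeps every use site in sync; changing the width
    // later is a one-line edit instead of a hunt through each signature.
    type RequestId = u64;
    type TokenId = u32;

    fn main() {
        let request_id: RequestId = 42;
        let tokens: Vec<TokenId> = vec![1, 15043, 2];
        println!("request {request_id} carries {} tokens", tokens.len());
    }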
@@ -50,7 +53,7 @@ mod ffi {
         #[rust_name = "submit"]
         fn Submit(
             self: Pin<&mut TensorRtLlmBackendImpl>,
-            tokens: &[u32],
+            tokens: &[TokenId],
             max_new_tokens: u32,
             top_k: i32,
             top_p: f32,
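In a cxx bridge, `#[rust_name = "submit"]` exposes the C++ method `Submit` under a snake_case name on the Rust side. A hedged sketch of a call site follows; the hunk above truncates the parameter list after `top_p`, so the remaining arguments are elided rather than guessed:

    // Illustrative only: calling the renamed method through the pinned handle.
    // let request_id = backend.pin_mut().submit(&tokens, max_new_tokens, top_k, top_p, /* remaining args elided */);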
@@ -1,11 +1,10 @@
 use std::hint;
 use std::ops::Deref;
 use std::path::Path;
-use std::sync::OnceLock;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use hashbrown::HashMap;
+use hashbrown::{HashMap, HashSet};
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::error::SendError;
@@ -13,7 +12,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info, span, Level};
+use tracing::{debug, debug_span, error, info, info_span, span, Level};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
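The newly imported `info_span!` and `debug_span!` are tracing's shorthand for `span!(Level::INFO, ...)` and `span!(Level::DEBUG, ...)`, which is what lets the looper below swap `span!(Level::DEBUG, ...)` for a one-macro call. A minimal, self-contained sketch:

    use tracing::{debug_span, info_span, span, Level};

    fn main() {
        // These two spans are equivalent ways of spelling the same thing.
        let a = span!(Level::INFO, "[in-flight][poll]");
        let b = info_span!("[in-flight][poll]");

        // in_scope runs a closure inside the span and returns its result.
        let ready = info_span!("poll").in_scope(|| 0 < 1);
        debug_span!("noop").in_scope(|| ());
        assert!(ready);
        let _ = (a, b);
    }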
@@ -26,32 +25,74 @@ use text_generation_router::{FinishReason, Token};
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 use crate::utils::first_line;
+use crate::RequestId;
 
-// Value used to poll the state of the generation stream
-static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
-
-// It's safe to send the backend between threads
-unsafe impl Send for TensorRtLlmBackendImpl {}
-
 type InferResult<T> = Result<T, InferError>;
 
+struct IdentifiableRequest<T> {
+    request_id: RequestId,
+    inner: T,
+}
+
+macro_rules! identifiable {
+    ($id: expr, $inner: expr) => {
+        IdentifiableRequest {
+            id: $id,
+            inner: $inner,
+        }
+    };
+}
+
+/// Wrap the TGI server forwarded ValidGenerateRequest with the tokenized view of the prompt
 struct ValidGenerateRequestWithTokens {
     encoding: Encoding,
     inner: ValidGenerateRequest,
 }
 
-struct DecodedTokenContext {
-    tokens: Vec<GenerationStep>,
-    ctx: UnboundedSender<InferResult<InferStreamResponse>>,
+/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
+struct GenerationContext {
+    request: ValidGenerateRequestWithTokens,
+    start: Instant,
+    queued: Option<Instant>,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
-fn executor_status_poller(
+#[derive(Debug, Copy, Clone)]
+struct DecodedToken {
+    id: u32,
+    log_prob: f32,
+    is_final: bool,
+}
+
+impl TryFrom<GenerationStep> for DecodedToken {
+    type Error = InferError;
+
+    fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
+        if !step.has_error {
+            Ok(Self {
+                id: step.token_id,
+                log_prob: step.log_prob,
+                is_final: step.is_final,
+            })
+        } else {
+            Err(GenerationError(step.error_msg))
+        }
+    }
+}
+
+/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
+struct DecodedTokenContext {
+    token: DecodedToken,
+    channel: UnboundedSender<InferResult<InferStreamResponse>>,
+}
+
+fn executor_status_looper(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    mut post_processor_sender: UnboundedSender<DecodedTokenContext>,
+    mut post_processor_sender: UnboundedSender<DecodedTokenContextWithRequestId>,
 ) {
     // Track the tuple (request_id, stream) for each request
-    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
+    let mut in_flights = HashMap::<RequestId, GenerationContext>::with_capacity(128);
 
     // TODO: Does it need a spin-loop?
     'executor: loop {
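The `TryFrom<GenerationStep> for DecodedToken` conversion folds the FFI error flag into Rust's `Result`. A standalone sketch of the same pattern with the library types stubbed out (`InferError` replaced by `String` so it runs on its own):

    // Same shape as the commit's conversion, simplified to be self-contained.
    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        has_error: bool,
        error_msg: String,
    }

    #[derive(Debug)]
    struct DecodedToken {
        id: u32,
        log_prob: f32,
        is_final: bool,
    }

    impl TryFrom<GenerationStep> for DecodedToken {
        type Error = String;

        fn try_from(step: GenerationStep) -> Result<Self, Self::Error> {
            if !step.has_error {
                Ok(Self { id: step.token_id, log_prob: step.log_prob, is_final: step.is_final })
            } else {
                Err(step.error_msg)
            }
        }
    }

    fn main() {
        let ok = GenerationStep {
            token_id: 7,
            log_prob: -0.3,
            is_final: false,
            has_error: false,
            error_msg: String::new(),
        };
        println!("{:?}", DecodedToken::try_from(ok));
    }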
@@ -60,7 +101,7 @@ fn executor_status_poller(
             let awaiting_requests = waiting_requests.len();
             for _ in 0..awaiting_requests {
                 // Retrieve all the requests
-                if let Some(ctx) = waiting_requests.blocking_recv() {
+                if let Some(mut ctx) = waiting_requests.blocking_recv() {
                     // Submit all the request to the executor and move the context to the in-flight tracker
                     let request = &ctx.request;
                     let generation_params = &request.inner.parameters;
@@ -79,13 +120,15 @@ fn executor_status_poller(
                 ) {
                     Ok(request_id) => {
                         // Insert the context linked to the generated request id in the tracker
+                        debug!("[in-flight] Added {}", request_id);
+                        ctx.queued = Instant::now();
                         in_flights.insert(request_id, ctx);
                     }
                     Err(e) => {
                         // Return to the caller
                         let what = Err(InferError::SchedulingError(e.to_string()));
-                        if let Err(e) = ctx.streamer.send(what) {
-                            error!("Failed to send back through the channel: {}", e);
+                        if let Err(ref e) = ctx.streamer.send(what) {
+                            error!("Failed to send the client", error = e.as_ref());
                         }
                     }
                 };
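With `start` recorded when the request enters the backend and `queued` stamped once the executor accepts it, downstream code can derive scheduling latency. A hypothetical helper (not part of the commit) showing the arithmetic over those two timestamps:

    use tokio::time::{Duration, Instant};

    // Hypothetical helper, not in the commit: time spent between the client
    // call (`start`) and acceptance by the executor (`queued`).
    fn queue_latency(start: Instant, queued: Option<Instant>) -> Option<Duration> {
        queued.map(|q| q.duration_since(start))
    }

    fn main() {
        let start = Instant::now();
        let queued = Some(Instant::now());
        println!("{:?}", queue_latency(start, queued));
    }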
@@ -93,83 +136,38 @@ fn executor_status_poller(
             }
         });
 
-        if let Err(e) = span!(Level::DEBUG, "[in-flight][poll]").in_scope(|| {
+        if let Err(ref e) = info_span!("[in-flight][poll]").in_scope(|| {
             if backend.num_responses_ready() > 0 {
-                match backend.pin_mut().pull_tokens() {
-                    Ok(responses) => {
-                        debug!("Received {} tokens from the executor", responses.len());
-
-                        // worse case scenario is one token for each response: with_capacity(responses.len())
-                        // grouper will group decoded tokens per request to decode multiple tokens
-                        let mut grouper: HashMap<u64, DecodedTokenContext> =
-                            HashMap::with_capacity(responses.len());
+                let responses = backend
+                    .pin_mut()
+                    .pull_tokens()
+                    .map_err(|e| Err(GenerationError(e.what())))?;
 
                 // Iterate through all the decoded token
                 for step in responses.deref() {
-                    match in_flights.get(&step.request_id) {
-                        Some(ctx) => {
-                            debug!(
-                                "{} -> (token={}, final={})",
-                                step.request_id, step.token_id, step.is_final
-                            );
+                    if let Some(ctx) = in_flights.get(&step.request_id) {
+                        let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                            token: dt,
+                            channel: ctx.streamer.clone(),
+                        });
 
-                            // If no error, let's forward to post-processor
-                            if !step.has_error {
-                                let req_group = grouper.entry(step.request_id).or_insert(
-                                    DecodedTokenContext {
-                                        tokens: vec![],
-                                        ctx: ctx.streamer.clone(), // Arc::clone() = cheap
-                                    },
-                                );
-                                req_group.tokens.push(step.clone()); // Should be ultra cheap
-                            } else {
-                                warn!(
-                                    "Error for request: {} -> {}",
-                                    step.request_id, &step.error_msg
-                                );
-
-                                // TODO: Send something back to the postprocessor for the client?
-                            }
+                        // Submit the work to the post_processor
+                        let delivered = post_processor_sender.send(parcel);
 
                         // Remove from tracked requests
                         if step.is_final {
+                            debug!("Removing {}", step.request_id);
                             let _ = in_flights.remove(&step.request_id);
                         }
-                        }
-                        None => {
-                            if step.has_error {
-                                error!(
-                                    "Untracked request {} -> {}",
-                                    step.request_id, &step.error_msg
-                                );
-                                continue;
+
+                        delivered
                     } else {
-                                error!(
-                                    "Got step for untracked request {}",
-                                    step.request_id
-                                );
+                        warn!("Untracked request {}", step.request_id,);
                     }
-                        }
-                    }
-
-                    grouper
-                        .into_values()
-                        .map(|ctx| post_processor_sender.send(ctx))
-                        .collect::<Result<(), SendError<DecodedTokenContext>>>()?;
-                    }
-                    Err(err) => {
-                        error!("Failed to retrieve tokens from the executor: {}", err);
-                    }
-                }
-            }
+                    }?;
                 }
 
-            Ok::<(), SendError<DecodedTokenContext>>(())
         }) {
-            error!(
-                "Caught an fatal error in the executor's loop, about to exit. {}",
-                e
-            );
+            error!("Error in the executor's loop, exiting", error = e.as_ref());
             break 'executor;
         }
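The looper threads communicate over tokio unbounded MPSC channels; `blocking_recv` lets a dedicated OS thread (spawned via `spawn_blocking`) pull work without an async context. A self-contained sketch of that pattern, mirroring the shape of `executor_status_looper` with an illustrative payload:

    use tokio::sync::mpsc::unbounded_channel;
    use tokio::task::spawn_blocking;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = unbounded_channel::<u64>();

        // A blocking looper thread: receives until the channel closes.
        let looper = spawn_blocking(move || {
            while let Some(request_id) = rx.blocking_recv() {
                println!("handling request {request_id}");
            }
        });

        for id in 0..3 {
            tx.send(id).expect("receiver alive");
        }
        drop(tx); // closing the channel ends the loop

        looper.await.expect("looper finished cleanly");
    }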
@@ -180,7 +178,7 @@ fn executor_status_poller(
 
 fn post_processor_looper(
     tokenizer: Tokenizer,
-    mut decoded_tokens: UnboundedReceiver<DecodedTokenContext>,
+    mut decoded_tokens: UnboundedReceiver<DecodedTokenContextWithRequestId>,
 ) {
     'post_processor: loop {
         if decoded_tokens.is_closed() {
@@ -188,56 +186,14 @@ fn post_processor_looper(
             break 'post_processor;
         }
 
-        if let Some(ctx) = decoded_tokens.blocking_recv() {
-            ctx.tokens.iter().for_each(|step| {
-                let out = match tokenizer.decode(&[step.token_id], true) {
-                    Ok(text) => {
-                        let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
-                        let token = Token {
-                            id: step.token_id,
-                            text,
-                            logprob: step.log_prob,
-                            special: is_special,
-                        };
-
-                        let response = if !step.is_final {
-                            InferStreamResponse::Intermediate {
-                                token,
-                                top_tokens: vec![],
-                            }
-                        } else {
-                            InferStreamResponse::End {
-                                token,
-                                top_tokens: vec![],
-                                generated_text: GeneratedText {
-                                    text: String::from(""),
-                                    generated_tokens: 0,
-                                    finish_reason: FinishReason::Length,
-                                    seed: None,
-                                },
-                                start: Instant::now(), // Handle start time
-                                queued: Instant::now(), // Handle queued time
-                            }
-                        };
-
-                        Ok(response)
-                    }
-                    Err(e) => Err(GenerationError(e.to_string())),
-                };
-
-                if let Err(e) = ctx.ctx.send(out) {
-                    warn!("Failed to send back the decoded tokens: {}", e);
-                };
-            });
+        let mut states = HashMap::with_capacity(128);
+
+        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
+            let state = states.entry(request_id).or_insert(vec![]);
         }
     }
 }
 
-struct GenerationContext {
-    request: ValidGenerateRequestWithTokens,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
-}
-
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
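The reworked post-processor keeps a per-request buffer and appends each incoming `(request_id, token)` pair to it. A minimal sketch of that accumulation pattern with simplified types (names illustrative, not from the commit):

    use hashbrown::HashMap;

    // Append a token to its request's running buffer, creating the buffer on
    // first sight of the request id.
    fn accumulate(states: &mut HashMap<u64, Vec<u32>>, request_id: u64, token_id: u32) {
        states.entry(request_id).or_insert_with(Vec::new).push(token_id);
    }

    fn main() {
        let mut states = HashMap::new();
        accumulate(&mut states, 7, 15043);
        accumulate(&mut states, 7, 2);
        println!("{:?}", states.get(&7));
    }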
@@ -277,7 +233,7 @@ impl TensorRtLlmBackendV2 {
 
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         let executor_looper = spawn_blocking(move || {
-            executor_status_poller(backend, executor_receiver, post_processor_sender)
+            executor_status_looper(backend, executor_receiver, post_processor_sender)
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
@@ -295,22 +251,22 @@ impl TensorRtLlmBackendV2 {
 
     fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
         if request.top_n_tokens > 1 {
-            return Err(InferError::ValidationError(TopNTokensDisabled));
+            return Err(ValidationError(TopNTokensDisabled));
         }
 
         // TODO: Is it really needed? How can it be validated before?
         if request.parameters.grammar.is_some() {
-            return Err(InferError::ValidationError(Grammar));
+            return Err(ValidationError(Grammar));
         }
 
         match request.inputs.len() {
-            0 => Err(InferError::ValidationError(EmptyInput)),
-            2.. => Err(InferError::GenerationError(
+            0 => Err(ValidationError(EmptyInput)),
+            2.. => Err(GenerationError(
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
                 Chunk::Text(text) => Ok(text),
-                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+                Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
     }
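The shortened `ValidationError(...)` and `GenerationError(...)` forms rely on the variant import `use text_generation_router::infer::InferError::{GenerationError, ValidationError};` from the top of the file. A self-contained sketch of that enum-variant-import trick (types simplified, not the library's):

    #[derive(Debug)]
    enum InferError {
        ValidationError(String),
        GenerationError(String),
    }

    // Importing the variants lets call sites drop the `InferError::` prefix,
    // which is exactly the cleanup applied in validate() above.
    use crate::InferError::{GenerationError, ValidationError};

    fn validate(input_chunks: usize) -> Result<(), InferError> {
        match input_chunks {
            0 => Err(ValidationError("empty input".into())),
            1 => Ok(()),
            _ => Err(GenerationError("multi-chunk not supported".into())),
        }
    }

    fn main() {
        println!("{:?}", validate(0));
        println!("{:?}", validate(1));
    }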
@@ -336,7 +292,13 @@ impl Backend for TensorRtLlmBackendV2 {
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
-        match self.executor.send(GenerationContext { request, streamer }) {
+        let start = Instant::now();
+        match self.executor.send(GenerationContext {
+            request,
+            start,
+            queued: None,
+            streamer,
+        }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
                 "Failed to submit request to the backend".into(),
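On success the caller receives an `UnboundedReceiverStream` over the same channel the loopers write into. A minimal sketch of consuming such a stream (payload simplified to `String`; the wrapper type is the one imported at the top of the file):

    use tokio::sync::mpsc::unbounded_channel;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_stream::StreamExt;

    #[tokio::main]
    async fn main() {
        let (streamer, receiver) = unbounded_channel::<String>();
        streamer.send("first token".into()).unwrap();
        streamer.send("final token".into()).unwrap();
        drop(streamer); // ends the stream

        // The client side just awaits items as they are decoded.
        let mut stream = UnboundedReceiverStream::new(receiver);
        while let Some(chunk) = stream.next().await {
            println!("{chunk}");
        }
    }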