Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
feat(backend): make sure we can easily cancel request on the executor
commit b3cd5ea076
parent 300f6c6f94
@@ -1,10 +1,9 @@
-use std::hint;
-use std::ops::Deref;
-use std::path::Path;
-
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
+use std::hint;
+use std::ops::Deref;
+use std::path::Path;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
@@ -30,9 +29,10 @@ type InferResult<T> = Result<T, InferError>;
 /// Wrap the requests along with the channel used to stream back to the client the decoded tokens
 struct GenerationContext {
     request: ValidGenerateRequest,
+    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokens: Vec<u32>,
     start: Option<Instant>,
     queued: Instant,
-    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }

 #[derive(Debug, Copy, Clone)]
@@ -58,31 +58,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     }
 }

-/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
-struct DecodedTokenContext {
-    token: DecodedToken,
-    start: Option<Instant>,
-    queued: Instant,
-    channel: UnboundedSender<InferResult<InferStreamResponse>>,
-}
-
 fn executor_status_looper(
-    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     max_inflight_requests: usize,
-    mut waiting_requests: UnboundedReceiver<GenerationContext>,
-    post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
+    tokenizer: Tokenizer,
+    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+    mut backlog: UnboundedReceiver<GenerationContext>,
 ) {
     // Track the tuple (request_id, stream) for each request
     let mut in_flights =
         HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);

-    // TODO: Does it need a spin-loop?
     'scheduler: loop {
         // Is there any request pending to be scheduled?
-        let awaiting_requests = waiting_requests.len();
+        let awaiting_requests = backlog.len();
         for _ in 0..awaiting_requests {
             // Retrieve all the requests
-            if let Some(mut ctx) = waiting_requests.blocking_recv() {
+            if let Some(ctx) = backlog.blocking_recv() {
                 // Submit all the request to the executor and move the context to the in-flight tracker
                 let request = &ctx.request;
                 let generation_params = &request.parameters;
@@ -103,7 +94,6 @@ fn executor_status_looper(
                     Ok(request_id) => {
                         // Insert the context linked to the generated request id in the tracker
                         debug!("[in-flight] Added {}", request_id);
-                        ctx.start = Some(Instant::now());
                         in_flights.insert(request_id, ctx);
                     }
                     Err(e) => {
@@ -117,6 +107,8 @@ fn executor_status_looper(
                         }
                     }
                 };
+            } else {
+                break 'scheduler;
             }
         }

@@ -125,21 +117,28 @@ fn executor_status_looper(
             Ok(responses) => {
                 // Iterate through all the decoded token
                 for step in responses.deref() {
-                    if let Some(ctx) = in_flights.get(&step.request_id) {
-                        // Remove from tracked requests
-                        let parcel =
-                            DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                                token: dt,
-                                start: ctx.start,
-                                queued: ctx.queued,
-                                channel: ctx.streamer.clone(),
-                            });
+                    if let Some(ctx) = in_flights.get_mut(&step.request_id) {
+                        // Update the starting timestamp if not set
+                        // This value might not be the actual real starting time of the request
+                        // on the executor side - Need to expose more info from the executor to
+                        // retrieve this value
+                        // TODO : Expose actual real starting time for a request on FFI layer
+                        if ctx.start.is_none() {
+                            ctx.start = Some(Instant::now());
+                        }

-                        // Submit the work to the post_processor
-                        let posted = post_processor_sender.send((step.request_id, parcel));
+                        // Try to map the generation step to a DecodedToken
+                        let response = match DecodedToken::try_from(step) {
+                            Ok(decoded_token) => {
+                                post_process_decoded_token(&tokenizer, ctx, decoded_token)
+                            }
+                            Err(err) => Err(err)
+                        };

-                        if posted.is_err() || step.is_final {
-                            debug!("Removing {}", step.request_id);
+                        // Attempt to send back the response to the client
+                        if let Err(_) = ctx.streamer.send(response) {
+                            // Client has dropped, remove from tracked requests
+                            debug!("Client dropped - removing request {} from tracked requests", step.request_id);
                             backend.pin_mut().cancel(step.request_id);
                             let _ = in_flights.remove(&step.request_id);
                         }
@@ -160,54 +159,33 @@ fn executor_status_looper(
     }
 }

-fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
-    tokenizer: Tokenizer,
-    max_inflight_requests: usize,
-    mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
-) {
-    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
-
-    'post_processor: loop {
-        if decoded_tokens.is_closed() {
-            warn!("Post processor IPC is closed, loop will exit now.");
-            break 'post_processor;
-        }
-
-        if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
-            match decoded {
-                Ok(ctx) => {
-                    states
-                        .entry(request_id)
-                        .and_modify(|s| s.push(*&ctx.token.id))
-                        .or_insert_with(|| {
-                            let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
-                            state.push(*&ctx.token.id);
-                            state
-                        });
-
-                    let out = match tokenizer.decode(&[ctx.token.id], false) {
+fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
+    match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
             let is_special =
                 tokenizer.get_added_vocabulary().is_special_token(&text);
             let token = Token {
-                id: ctx.token.id,
+                id: decoded_token.id,
                 text,
-                logprob: ctx.token.log_prob,
+                logprob: decoded_token.log_prob,
                 special: is_special,
             };

-            let out = if !ctx.token.is_final {
+            // Append the token to the tracked generated tokens
+            ctx.tokens.push(token.id);
+
+            // Map the correct response depending on the step is final or not
+            let out = if !decoded_token.is_final {
                 InferStreamResponse::Intermediate {
                     token,
                     top_tokens: vec![],
                 }
             } else {
-                let tokens = states.remove(&request_id).unwrap();
-                let text = tokenizer.decode(&tokens, true);
+                let text = tokenizer.decode(&ctx.tokens, true);
                 let generated_text = GeneratedText {
                     text: text.unwrap(),
-                    generated_tokens: tokens.len() as u32,
-                    finish_reason: FinishReason::EndOfSequenceToken,
+                    generated_tokens: ctx.tokens.len() as u32,
+                    finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
                     seed: None,
                 };

@@ -223,17 +201,6 @@ fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
             Ok(out)
         }
         Err(err) => Err(GenerationError(err.to_string())),
-                    };
-
-                    if let Err(_) = ctx.channel.send(out) {
-                        warn!("Failed to send decoded token back to the user")
-                    }
-                }
-                Err(_err) => {
-                    todo!("what do we do?")
-                }
-            }
-        }
     }
 }

@@ -278,11 +245,8 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(

 unsafe impl Send for TensorRtLlmBackendImpl {}

-pub struct TensorRtLlmBackendV2 {
-    executor_looper: JoinHandle<()>,
-    post_processor_looper: JoinHandle<()>,
-    executor: UnboundedSender<GenerationContext>,
-}
+pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);

 impl TensorRtLlmBackendV2 {
     pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
@@ -296,32 +260,22 @@ impl TensorRtLlmBackendV2 {

         // Allocate the IPC layer to communicate with the backend
         let (executor_sender, executor_receiver) = unbounded_channel();
-        let (post_processor_sender, post_processor_receiver) = unbounded_channel();

         // Create the FFI backend
         let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
             .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

         // Executor looper is responsible for scheduling and pulling requests state at regular interval
-        let executor_looper = spawn_blocking(move || {
+        spawn_blocking(move || {
             executor_status_looper(
-                backend,
                 max_inflight_requests,
+                tokenizer,
+                backend,
                 executor_receiver,
-                post_processor_sender,
             )
         });

-        // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
-        let post_processor_looper = spawn_blocking(move || {
-            post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
-        });
-
-        Ok(TensorRtLlmBackendV2 {
-            executor_looper,
-            post_processor_looper,
-            executor: executor_sender,
-        })
+        Ok(TensorRtLlmBackendV2(executor_sender))
     }

     fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
@@ -355,20 +309,21 @@ impl TensorRtLlmBackendV2 {
 impl Backend for TensorRtLlmBackendV2 {
     fn schedule(
         &self,
-        inner: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        Self::validate(&inner)?;
+        Self::validate(&request)?;

         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();

         // Send the context to the executor for scheduling
         let queued = Instant::now();
-        match self.executor.send(GenerationContext {
-            request: inner,
+        match self.send(GenerationContext {
+            request,
+            streamer,
+            tokens: Vec::with_capacity(256),
             start: None,
             queued,
-            streamer,
         }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
@@ -378,6 +333,6 @@ impl Backend for TensorRtLlmBackendV2 {
     }

     async fn health(&self, _: bool) -> bool {
-        !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
+        true
     }
 }
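The heart of this change is the cancellation path in executor_status_looper: tokens are now post-processed inline and streamed straight to the client, and a failed send on the per-request channel (the client dropped its response stream) triggers backend.pin_mut().cancel(request_id) plus removal from the in-flight map. Below is a minimal, dependency-free sketch of that pattern, not the actual TGI backend code: MockExecutor is a hypothetical stand-in for the FFI TensorRT-LLM executor, and std::sync::mpsc stands in for tokio's unbounded channel.

// Sketch only: cancel a request on the executor once the client's
// response stream is gone, mirroring the pattern in the diff above.
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};

// Hypothetical stand-in for the FFI TensorRT-LLM executor.
struct MockExecutor;

impl MockExecutor {
    fn cancel(&mut self, request_id: u64) {
        println!("executor: cancelled request {request_id}");
    }
}

fn main() {
    let mut executor = MockExecutor;
    let mut in_flights: HashMap<u64, Sender<String>> = HashMap::new();

    // Track one request whose client immediately goes away.
    let (streamer, receiver) = channel::<String>();
    in_flights.insert(42, streamer);
    drop(receiver); // simulate the client dropping its response stream

    // Try to deliver a token to every tracked request.
    let ids: Vec<u64> = in_flights.keys().copied().collect();
    for request_id in ids {
        if in_flights[&request_id].send("token".into()).is_err() {
            // Client has dropped: cancel on the executor and stop tracking it.
            executor.cancel(request_id);
            in_flights.remove(&request_id);
        }
    }

    assert!(in_flights.is_empty());
}

The design choice is the same as in the diff: a closed response channel is treated as the signal that the request can be cancelled on the executor side instead of letting it keep generating for a client that is no longer listening.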