mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
feat(backend): make sure we can easily cancel request on the executor
This commit is contained in:
parent
300f6c6f94
commit
b3cd5ea076
@ -1,10 +1,9 @@
|
||||
use std::hint;
|
||||
use std::ops::Deref;
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cxx::UniquePtr;
|
||||
use hashbrown::HashMap;
|
||||
use std::hint;
|
||||
use std::ops::Deref;
|
||||
use std::path::Path;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
||||
use tokio::sync::TryAcquireError;
|
||||
@ -30,9 +29,10 @@ type InferResult<T> = Result<T, InferError>;
|
||||
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
|
||||
struct GenerationContext {
|
||||
request: ValidGenerateRequest,
|
||||
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokens: Vec<u32>,
|
||||
start: Option<Instant>,
|
||||
queued: Instant,
|
||||
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@ -58,31 +58,22 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
|
||||
struct DecodedTokenContext {
|
||||
token: DecodedToken,
|
||||
start: Option<Instant>,
|
||||
queued: Instant,
|
||||
channel: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
}
|
||||
|
||||
fn executor_status_looper(
|
||||
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
|
||||
max_inflight_requests: usize,
|
||||
mut waiting_requests: UnboundedReceiver<GenerationContext>,
|
||||
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
|
||||
tokenizer: Tokenizer,
|
||||
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
|
||||
mut backlog: UnboundedReceiver<GenerationContext>,
|
||||
) {
|
||||
// Track the tuple (request_id, stream) for each request
|
||||
let mut in_flights =
|
||||
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
|
||||
|
||||
// TODO: Does it need a spin-loop?
|
||||
'scheduler: loop {
|
||||
// Is there any request pending to be scheduled?
|
||||
let awaiting_requests = waiting_requests.len();
|
||||
let awaiting_requests = backlog.len();
|
||||
for _ in 0..awaiting_requests {
|
||||
// Retrieve all the requests
|
||||
if let Some(mut ctx) = waiting_requests.blocking_recv() {
|
||||
if let Some(ctx) = backlog.blocking_recv() {
|
||||
// Submit all the request to the executor and move the context to the in-flight tracker
|
||||
let request = &ctx.request;
|
||||
let generation_params = &request.parameters;
|
||||
@ -103,7 +94,6 @@ fn executor_status_looper(
|
||||
Ok(request_id) => {
|
||||
// Insert the context linked to the generated request id in the tracker
|
||||
debug!("[in-flight] Added {}", request_id);
|
||||
ctx.start = Some(Instant::now());
|
||||
in_flights.insert(request_id, ctx);
|
||||
}
|
||||
Err(e) => {
|
||||
@ -117,6 +107,8 @@ fn executor_status_looper(
|
||||
}
|
||||
}
|
||||
};
|
||||
} else {
|
||||
break 'scheduler;
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,21 +117,28 @@ fn executor_status_looper(
|
||||
Ok(responses) => {
|
||||
// Iterate through all the decoded token
|
||||
for step in responses.deref() {
|
||||
if let Some(ctx) = in_flights.get(&step.request_id) {
|
||||
// Remove from tracked requests
|
||||
let parcel =
|
||||
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
|
||||
token: dt,
|
||||
start: ctx.start,
|
||||
queued: ctx.queued,
|
||||
channel: ctx.streamer.clone(),
|
||||
});
|
||||
if let Some(ctx) = in_flights.get_mut(&step.request_id) {
|
||||
// Update the starting timestamp if not set
|
||||
// This value might not be the actual real starting time of the request
|
||||
// on the executor side - Need to expose more info from the executor to
|
||||
// retrieve this value
|
||||
// TODO : Expose actual real starting time for a request on FFI layer
|
||||
if ctx.start.is_none() {
|
||||
ctx.start = Some(Instant::now());
|
||||
}
|
||||
|
||||
// Submit the work to p:the post_processor
|
||||
let posted = post_processor_sender.send((step.request_id, parcel));
|
||||
// Try to map the generation step to a DecodedToken
|
||||
let response = match DecodedToken::try_from(step) {
|
||||
Ok(decoded_token) => {
|
||||
post_process_decoded_token(&tokenizer, ctx, decoded_token)
|
||||
}
|
||||
Err(err) => Err(err)
|
||||
};
|
||||
|
||||
if posted.is_err() || step.is_final {
|
||||
debug!("Removing {}", step.request_id);
|
||||
// Attempt to send back the response to the client
|
||||
if let Err(_) = ctx.streamer.send(response) {
|
||||
// Client has dropped, remove from tracked requests
|
||||
debug!("Client dropped - removing request {} from tracked requests", step.request_id);
|
||||
backend.pin_mut().cancel(step.request_id);
|
||||
let _ = in_flights.remove(&step.request_id);
|
||||
}
|
||||
@ -160,80 +159,48 @@ fn executor_status_looper(
|
||||
}
|
||||
}
|
||||
|
||||
fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
|
||||
tokenizer: Tokenizer,
|
||||
max_inflight_requests: usize,
|
||||
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
|
||||
) {
|
||||
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
|
||||
fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
|
||||
match tokenizer.decode(&[decoded_token.id], false) {
|
||||
Ok(text) => {
|
||||
let is_special =
|
||||
tokenizer.get_added_vocabulary().is_special_token(&text);
|
||||
let token = Token {
|
||||
id: decoded_token.id,
|
||||
text,
|
||||
logprob: decoded_token.log_prob,
|
||||
special: is_special,
|
||||
};
|
||||
|
||||
'post_processor: loop {
|
||||
if decoded_tokens.is_closed() {
|
||||
warn!("Post processor IPC is closed, loop will exit now.");
|
||||
break 'post_processor;
|
||||
}
|
||||
// Append the token to the tracked generated tokens
|
||||
ctx.tokens.push(token.id);
|
||||
|
||||
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
|
||||
match decoded {
|
||||
Ok(ctx) => {
|
||||
states
|
||||
.entry(request_id)
|
||||
.and_modify(|s| s.push(*&ctx.token.id))
|
||||
.or_insert_with(|| {
|
||||
let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
|
||||
state.push(*&ctx.token.id);
|
||||
state
|
||||
});
|
||||
|
||||
let out = match tokenizer.decode(&[ctx.token.id], false) {
|
||||
Ok(text) => {
|
||||
let is_special =
|
||||
tokenizer.get_added_vocabulary().is_special_token(&text);
|
||||
let token = Token {
|
||||
id: ctx.token.id,
|
||||
text,
|
||||
logprob: ctx.token.log_prob,
|
||||
special: is_special,
|
||||
};
|
||||
|
||||
let out = if !ctx.token.is_final {
|
||||
InferStreamResponse::Intermediate {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
}
|
||||
} else {
|
||||
let tokens = states.remove(&request_id).unwrap();
|
||||
let text = tokenizer.decode(&tokens, true);
|
||||
let generated_text = GeneratedText {
|
||||
text: text.unwrap(),
|
||||
generated_tokens: tokens.len() as u32,
|
||||
finish_reason: FinishReason::EndOfSequenceToken,
|
||||
seed: None,
|
||||
};
|
||||
|
||||
InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
generated_text,
|
||||
start: ctx.start.unwrap(),
|
||||
queued: ctx.queued,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
Err(err) => Err(GenerationError(err.to_string())),
|
||||
};
|
||||
|
||||
if let Err(_) = ctx.channel.send(out) {
|
||||
warn!("Failed to send decoded token back to the user")
|
||||
}
|
||||
// Map the correct response depending on the step is final or not
|
||||
let out = if !decoded_token.is_final {
|
||||
InferStreamResponse::Intermediate {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
}
|
||||
Err(_err) => {
|
||||
todo!("what do we do?")
|
||||
} else {
|
||||
let text = tokenizer.decode(&ctx.tokens, true);
|
||||
let generated_text = GeneratedText {
|
||||
text: text.unwrap(),
|
||||
generated_tokens: ctx.tokens.len() as u32,
|
||||
finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
|
||||
seed: None,
|
||||
};
|
||||
|
||||
InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
generated_text,
|
||||
start: ctx.start.unwrap(),
|
||||
queued: ctx.queued,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
Err(err) => Err(GenerationError(err.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
@ -278,11 +245,8 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
|
||||
|
||||
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||
|
||||
pub struct TensorRtLlmBackendV2 {
|
||||
executor_looper: JoinHandle<()>,
|
||||
post_processor_looper: JoinHandle<()>,
|
||||
executor: UnboundedSender<GenerationContext>,
|
||||
}
|
||||
pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);
|
||||
|
||||
|
||||
impl TensorRtLlmBackendV2 {
|
||||
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
|
||||
@ -296,32 +260,22 @@ impl TensorRtLlmBackendV2 {
|
||||
|
||||
// Allocate the IPC layer to communicate with the backend
|
||||
let (executor_sender, executor_receiver) = unbounded_channel();
|
||||
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
|
||||
|
||||
// Create the FFI backend
|
||||
let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
|
||||
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
|
||||
|
||||
// Executor looper is responsible for scheduling and pulling requests state at regular interval
|
||||
let executor_looper = spawn_blocking(move || {
|
||||
spawn_blocking(move || {
|
||||
executor_status_looper(
|
||||
backend,
|
||||
max_inflight_requests,
|
||||
tokenizer,
|
||||
backend,
|
||||
executor_receiver,
|
||||
post_processor_sender,
|
||||
)
|
||||
});
|
||||
|
||||
// Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
|
||||
let post_processor_looper = spawn_blocking(move || {
|
||||
post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
|
||||
});
|
||||
|
||||
Ok(TensorRtLlmBackendV2 {
|
||||
executor_looper,
|
||||
post_processor_looper,
|
||||
executor: executor_sender,
|
||||
})
|
||||
Ok(TensorRtLlmBackendV2(executor_sender))
|
||||
}
|
||||
|
||||
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
|
||||
@ -355,20 +309,21 @@ impl TensorRtLlmBackendV2 {
|
||||
impl Backend for TensorRtLlmBackendV2 {
|
||||
fn schedule(
|
||||
&self,
|
||||
inner: ValidGenerateRequest,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
Self::validate(&inner)?;
|
||||
Self::validate(&request)?;
|
||||
|
||||
// Open-up the stream to send tokens
|
||||
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
|
||||
|
||||
// Send the context to the executor for scheduling
|
||||
let queued = Instant::now();
|
||||
match self.executor.send(GenerationContext {
|
||||
request: inner,
|
||||
match self.send(GenerationContext {
|
||||
request,
|
||||
streamer,
|
||||
tokens: Vec::with_capacity(256),
|
||||
start: None,
|
||||
queued,
|
||||
streamer,
|
||||
}) {
|
||||
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
|
||||
Err(_) => Err(GenerationError(
|
||||
@ -378,6 +333,6 @@ impl Backend for TensorRtLlmBackendV2 {
|
||||
}
|
||||
|
||||
async fn health(&self, _: bool) -> bool {
|
||||
!self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
|
||||
true
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user