feat(router): drop permit after batching

This commit is contained in:
OlivierDehaene 2024-07-23 14:45:30 +02:00
parent e7e3aa6cac
commit 344427b6ab
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
6 changed files with 18 additions and 15 deletions

View File

@ -155,7 +155,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives // Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; let (_input_length, mut stream) = self.generate_stream(request).await?;
// Return values // Return values
let mut result_prefill = Vec::new(); let mut result_prefill = Vec::new();
@ -462,7 +462,6 @@ impl ToolGrammar {
/// Type alias for generation responses /// Type alias for generation responses
pub(crate) type GenerateStreamResponse = ( pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
); );

View File

@ -9,7 +9,7 @@ use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
}; };
use text_generation_client::ChunksToString; use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Span}; use tracing::{info_span, instrument, Span};
@ -18,6 +18,8 @@ use tracing::{info_span, instrument, Span};
pub(crate) struct Entry { pub(crate) struct Entry {
/// Request /// Request
pub request: ValidGenerateRequest, pub request: ValidGenerateRequest,
/// Permit
pub permit: Option<OwnedSemaphorePermit>,
/// Response sender to communicate between the Infer struct and the batching_task /// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>, pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry /// Span that will live as long as entry
@ -269,6 +271,9 @@ impl State {
break; break;
} }
// Drop permit
entry.permit = None;
tracing::debug!("Accepting entry"); tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer"); let entry_batch_span = info_span!(parent: &entry.span, "infer");

View File

@ -84,6 +84,7 @@ impl Scheduler for SchedulerV2 {
self.queue.append(Entry { self.queue.append(Entry {
request, request,
response_tx, response_tx,
permit: Some(permit),
span: Span::current(), span: Span::current(),
temp_span: None, temp_span: None,
queue_time: Instant::now(), queue_time: Instant::now(),
@ -95,11 +96,7 @@ impl Scheduler for SchedulerV2 {
self.batching_task_notifier.notify_one(); self.batching_task_notifier.notify_one();
// Return stream // Return stream
Ok(( Ok((input_length, UnboundedReceiverStream::new(response_rx)))
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
} }
} }

View File

@ -12,7 +12,7 @@ use text_generation_client::v3::{
}; };
use text_generation_client::ChunksToString; use text_generation_client::ChunksToString;
use text_generation_client::Input; use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span}; use tracing::{info_span, instrument, Instrument, Span};
@ -21,6 +21,8 @@ use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct Entry { pub(crate) struct Entry {
/// Request /// Request
pub request: ValidGenerateRequest, pub request: ValidGenerateRequest,
/// Permit
pub permit: Option<OwnedSemaphorePermit>,
/// Response sender to communicate between the Infer struct and the batching_task /// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>, pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry /// Span that will live as long as entry
@ -315,6 +317,9 @@ impl State {
} }
}; };
// Drop permit
entry.permit = None;
tracing::debug!("Accepting entry"); tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer"); let entry_batch_span = info_span!(parent: &entry.span, "infer");

View File

@ -89,6 +89,7 @@ impl Scheduler for SchedulerV3 {
self.queue.append(Entry { self.queue.append(Entry {
request, request,
response_tx, response_tx,
permit: Some(permit),
span: Span::current(), span: Span::current(),
temp_span: None, temp_span: None,
queue_time: Instant::now(), queue_time: Instant::now(),
@ -101,11 +102,7 @@ impl Scheduler for SchedulerV3 {
self.batching_task_notifier.notify_one(); self.batching_task_notifier.notify_one();
// Return stream // Return stream
Ok(( Ok((input_length, UnboundedReceiverStream::new(response_rx)))
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
} }
} }

View File

@ -429,7 +429,7 @@ async fn generate_stream_internal(
} else { } else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives // Keep permit as long as generate_stream lives
Ok((_permit, _input_length, mut response_stream)) => { Ok((_input_length, mut response_stream)) => {
let mut index = 0; let mut index = 0;
// Server-Sent Event stream // Server-Sent Event stream
while let Some(response) = response_stream.next().await { while let Some(response) = response_stream.next().await {