diff --git a/router/src/infer.rs b/router/src/infer.rs
index 9618264d..1eb37e6a 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -3,14 +3,16 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
+use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use flume::SendError;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
-use tokio::sync::{Notify, Semaphore, TryAcquireError};
+use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
 
@@ -72,9 +74,14 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            RecvStream<Result<InferStreamResponse, InferError>>,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
-        // This permit will live as long as Entry
         let permit = self
             .clone()
             .limit_concurrent_requests
@@ -103,7 +110,6 @@ impl Infer {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         });
 
         // Notify the background task that we have a new entry in the queue that needs
@@ -111,7 +117,7 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok(response_rx.into_stream())
+        Ok((permit, response_rx.into_stream()))
     }
 
     /// Add a new request to the queue and return a InferResponse
@@ -120,8 +126,8 @@ impl Infer {
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
-        // Create stream
-        let mut stream = self.generate_stream(request).await?;
+        // Create stream and keep semaphore permit as long as generate lives
+        let (_permit, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -275,12 +281,10 @@ async fn batching_task(
                     .next_batch(min_size, max_batch_size - batch_size as usize)
                     .await
                 {
-                    let new_batch_size = new_batch.size;
                     entries.iter_mut().for_each(|(_, entry)| {
                         // Create a new span to add the info that this entry is waiting
                         // because a new batch is being computed
-                        let entry_waiting_span =
-                            info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                         // Add relationships
                         span.follows_from(&entry_waiting_span);
                         entry_waiting_span.follows_from(&span);
@@ -307,8 +311,7 @@ async fn batching_task(
                info_span!(parent: None, "batch", batch_size = next_batch_size);
            entries.iter_mut().for_each(|(_, entry)| {
                // Create a new span to link the batch back to this entry
-                let entry_batch_span =
-                    info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
+                let entry_batch_span = info_span!(parent: &entry.span, "infer");
                // Add relationships
                next_batch_span.follows_from(&entry_batch_span);
                entry_batch_span.follows_from(&next_batch_span);
@@ -340,17 +343,19 @@ async fn prefill(
         Ok((generations, next_batch)) => {
             filter_send_generations(generations, entries);
 
-            let next_batch = {
-                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
-
-                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
-                let size = batch.requests.len();
-                if size == 0 {
-                    let _ = client.clear_cache(Some(batch.id)).await;
-                    return None;
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
                 }
-                batch.size = size as u32;
-                Some(batch)
             };
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
@@ -381,17 +386,19 @@ async fn decode(
         Ok((generations, next_batch)) => {
             filter_send_generations(generations, entries);
 
-            let next_batch = {
-                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
-
-                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
-                let size = batch.requests.len();
-                if size == 0 {
-                    let _ = client.clear_cache(Some(batch.id)).await;
-                    return None;
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
                 }
-                batch.size = size as u32;
-                Some(batch)
             };
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
@@ -410,22 +417,16 @@ async fn decode(
     }
 }
 
-/// Send errors to Infer for all `entries`
+/// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
-    entries.drain().for_each(|(_, entry)| {
-        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-        let err = InferError::GenerationError(error.to_string());
-        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
-        tracing::error!("{err}");
-
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Err(err))
-            .unwrap_or(());
-    });
+fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+    batch.requests.retain(|r| entries.contains_key(&r.id));
+    let size = batch.requests.len();
+    if size == 0 {
+        return None;
+    }
+    batch.size = size as u32;
+    Some(batch)
 }
 
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
@@ -442,22 +443,27 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u
-fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+/// Send responses through the `entry` response channel
+fn send_responses(
+    generation: Generation,
+    entry: &Entry,
+) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
     let mut stopped = false;
 
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx
+        entry
+            .response_tx
             .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }
 
@@ -473,22 +479,39 @@ fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendE
+/// Send errors to Infer for all `entries`
+#[instrument(skip_all)]
+fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
+        // Create and enter a span to link this function back to the entry
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
+        tracing::error!("{err}");
+
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry
+            .response_tx
+            .send(Err(err))
+            .unwrap_or(());
+    });
+}
+
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 11eb7f59..93855827 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -4,7 +4,7 @@ use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use text_generation_client::{Batch, Request};
-use tokio::sync::{oneshot, OwnedSemaphorePermit};
+use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
 
@@ -23,8 +23,6 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
-    /// Permit
-    pub _permit: OwnedSemaphorePermit,
 }
 
 /// Request Queue
@@ -147,46 +145,53 @@ impl State {
             }
         }
 
-        let next_batch_size = min(self.entries.len(), max_size);
+        let max_batch_size = min(self.entries.len(), max_size);
 
         // Create span for this batch to add context to inference calls
-        let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
+        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());
 
-        let mut batch_requests = Vec::with_capacity(next_batch_size);
+        let mut batch_requests = Vec::with_capacity(max_batch_size);
         let mut batch_entries =
-            IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
+            IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
 
         // Drain next_batch_size entries
-        self.entries
-            .drain(..next_batch_size)
-            .for_each(|(id, mut entry)| {
-                // Create a new span to link the batch back to this entry
-                let entry_batch_span =
-                    info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                // Add relationships
-                next_batch_span.follows_from(&entry_batch_span);
-                entry_batch_span.follows_from(&next_batch_span);
-                // Update entry
-                entry.temp_span = Some(entry_batch_span);
+        for (id, mut entry) in self.entries.drain(..max_batch_size) {
+            // Filter entries where the response receiver was dropped (== entries where the request
+            // was dropped by the client)
+            if entry.response_tx.is_disconnected() {
+                continue;
+            }
 
-                batch_requests.push(Request {
-                    id,
-                    inputs: entry.request.inputs.clone(),
-                    truncate: entry.request.truncate,
-                    parameters: Some(entry.request.parameters.clone()),
-                    stopping_parameters: Some(entry.request.stopping_parameters.clone()),
-                });
-                // Set batch_time
-                entry.batch_time = Some(Instant::now());
-                // Insert in batch_entries IntMap
-                batch_entries.insert(id, entry);
+            // Create a new span to link the batch back to this entry
+            let entry_batch_span = info_span!(parent: &entry.span, "infer");
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
+            entry_batch_span.follows_from(&next_batch_span);
+            // Update entry
+            entry.temp_span = Some(entry_batch_span);
+
+            batch_requests.push(Request {
+                id,
+                inputs: entry.request.inputs.clone(),
+                truncate: entry.request.truncate,
+                parameters: Some(entry.request.parameters.clone()),
+                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });
+            // Set batch_time
+            entry.batch_time = Some(Instant::now());
+            // Insert in batch_entries IntMap
+            batch_entries.insert(id, entry);
+        }
+
+        // Final batch size once we dropped entries
+        let size = batch_requests.len() as u32;
+        next_batch_span.record("batch_size", size);
 
         let batch = Batch {
             id: self.next_batch_id,
             requests: batch_requests,
-            size: next_batch_size as u32,
+            size,
         };
         // Increment batch id
         self.next_batch_id += 1;
@@ -219,9 +224,7 @@ mod tests {
     use tracing::info_span;
 
     fn default_entry() -> Entry {
-        let semaphore = Arc::new(Semaphore::new(1));
         let (response_tx, _) = flume::unbounded();
-        let permit = semaphore.try_acquire_owned().unwrap();
 
         Entry {
             request: ValidGenerateRequest {
@@ -248,7 +251,6 @@
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         }
     }
diff --git a/router/src/server.rs b/router/src/server.rs
index ce301399..fee748e6 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -367,7 +367,8 @@ async fn generate_stream(
     let best_of = req.0.parameters.best_of.unwrap_or(1);
     if best_of == 1 {
         match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
-            Ok(mut response_stream) => {
+            // Keep permit as long as generate_stream lives
+            Ok((_permit, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
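
Note on the API change above (illustration only, not part of the patch): `generate_stream` now hands the `OwnedSemaphorePermit` back to the caller instead of storing it in the queue `Entry`, so the concurrency slot stays reserved exactly as long as the caller keeps the response stream alive and is released as soon as the response (or the disconnecting client) drops it. A minimal standalone sketch of that ownership pattern with `tokio::sync::Semaphore`; `MAX_CONCURRENT_REQUESTS` and `handle_request` are made-up names for this sketch:

use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};

// Hypothetical limit, only for this sketch.
const MAX_CONCURRENT_REQUESTS: usize = 128;

// Return the permit together with the response value: whoever consumes the
// response also owns the permit, and dropping either end releases the slot.
fn handle_request(
    limit: &Arc<Semaphore>,
) -> Result<(OwnedSemaphorePermit, String), TryAcquireError> {
    let permit = limit.clone().try_acquire_owned()?;
    Ok((permit, String::from("response placeholder")))
}

fn main() {
    let limit = Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS));
    match handle_request(&limit) {
        Ok((_permit, response)) => {
            println!("{response}");
            // `_permit` is dropped at the end of this block, freeing one slot
            // even if the response was never fully consumed.
        }
        // `try_acquire_owned` fails when the limiter is at capacity (or closed).
        Err(err) => println!("request rejected: {err}"),
    }
}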