OlivierDehaene 2023-04-16 18:52:47 +02:00
parent 4e63d9cb28
commit 9476170dda
3 changed files with 125 additions and 99 deletions

router/src/infer.rs

@@ -3,14 +3,16 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
+use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use flume::SendError;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
-use tokio::sync::{Notify, Semaphore, TryAcquireError};
+use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
@@ -72,9 +74,14 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            RecvStream<Result<InferStreamResponse, InferError>>,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
-        // This permit will live as long as Entry
         let permit = self
             .clone()
             .limit_concurrent_requests
@@ -103,7 +110,6 @@ impl Infer {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -111,7 +117,7 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok(response_rx.into_stream())
+        Ok((permit, response_rx.into_stream()))
     }

     /// Add a new request to the queue and return a InferResponse
@@ -120,8 +126,8 @@ impl Infer {
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
-        // Create stream
-        let mut stream = self.generate_stream(request).await?;
+        // Create stream and keep semaphore permit as long as generate lives
+        let (_permit, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -275,12 +281,10 @@ async fn batching_task(
                 .next_batch(min_size, max_batch_size - batch_size as usize)
                 .await
             {
-                let new_batch_size = new_batch.size;
                 entries.iter_mut().for_each(|(_, entry)| {
                     // Create a new span to add the info that this entry is waiting
                     // because a new batch is being computed
-                    let entry_waiting_span =
-                        info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                    let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                     // Add relationships
                     span.follows_from(&entry_waiting_span);
                     entry_waiting_span.follows_from(&span);
@@ -307,8 +311,7 @@ async fn batching_task(
                         info_span!(parent: None, "batch", batch_size = next_batch_size);
                     entries.iter_mut().for_each(|(_, entry)| {
                         // Create a new span to link the batch back to this entry
-                        let entry_batch_span =
-                            info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
+                        let entry_batch_span = info_span!(parent: &entry.span, "infer");
                         // Add relationships
                         next_batch_span.follows_from(&entry_batch_span);
                         entry_batch_span.follows_from(&next_batch_span);
@@ -340,17 +343,19 @@ async fn prefill(
         Ok((generations, next_batch)) => {
             filter_send_generations(generations, entries);

-            let next_batch = {
-                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
-
-                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
-                let size = batch.requests.len();
-                if size == 0 {
-                    let _ = client.clear_cache(Some(batch.id)).await;
-                    return None;
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
                 }
-                batch.size = size as u32;
-                Some(batch)
             };

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
@@ -381,17 +386,19 @@ async fn decode(
         Ok((generations, next_batch)) => {
             filter_send_generations(generations, entries);

-            let next_batch = {
-                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
-
-                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
-                let size = batch.requests.len();
-                if size == 0 {
-                    let _ = client.clear_cache(Some(batch.id)).await;
-                    return None;
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
                 }
-                batch.size = size as u32;
-                Some(batch)
             };

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
@@ -410,22 +417,16 @@ async fn decode(
     }
 }

-/// Send errors to Infer for all `entries`
+/// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
-    entries.drain().for_each(|(_, entry)| {
-        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-        let err = InferError::GenerationError(error.to_string());
-        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
-        tracing::error!("{err}");
-
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Err(err))
-            .unwrap_or(());
-    });
+fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+    batch.requests.retain(|r| entries.contains_key(&r.id));
+    let size = batch.requests.len();
+    if size == 0 {
+        return None;
+    }
+    batch.size = size as u32;
+    Some(batch)
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
@@ -442,22 +443,27 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
         // Create and enter a span to link this function back to the entry
         let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();

-        // Send generation back to infer task
-        // If the receive an error from the Flume channel, we need to stop generating for this
-        // request hence why we unwrap_or(true)
-        let stopped = send_generation(generation, entry).unwrap_or(true);
+        // Send generation responses back to the infer task
+        // If the receive an error from the Flume channel, it means that the client dropped the
+        // request and we need to stop generating hence why we unwrap_or(true)
+        let stopped = send_responses(generation, entry).unwrap_or(true);
         if stopped {
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
         }
     });
 }

-fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+/// Send responses through the `entry` response channel
+fn send_responses(
+    generation: Generation,
+    entry: &Entry,
+) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
     let mut stopped = false;

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx
+        entry
+            .response_tx
             .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }
@@ -473,8 +479,7 @@ fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendEr
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx
-            .send(Ok(InferStreamResponse::End {
+        entry.response_tx.send(Ok(InferStreamResponse::End {
             token,
             generated_text,
             queued: entry.queue_time,
@@ -482,13 +487,31 @@ fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendEr
         }))?;
     } else {
         // Send message
-        entry.response_tx
-            .send(Ok(InferStreamResponse::Token(token)))
-            ?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Token(token)))?;
     }
     Ok(stopped)
 }

+/// Send errors to Infer for all `entries`
+#[instrument(skip_all)]
+fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
+        // Create and enter a span to link this function back to the entry
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
+        tracing::error!("{err}");
+
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry
+            .response_tx
+            .send(Err(err))
+            .unwrap_or(());
+    });
+}
+
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message

router/src/queue.rs

@@ -4,7 +4,7 @@ use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use text_generation_client::{Batch, Request};
-use tokio::sync::{oneshot, OwnedSemaphorePermit};
+use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
@@ -23,8 +23,6 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
-    /// Permit
-    pub _permit: OwnedSemaphorePermit,
 }

 /// Request Queue
@@ -147,23 +145,26 @@ impl State {
             }
         }

-        let next_batch_size = min(self.entries.len(), max_size);
+        let max_batch_size = min(self.entries.len(), max_size);

         // Create span for this batch to add context to inference calls
-        let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
+        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());

-        let mut batch_requests = Vec::with_capacity(next_batch_size);
+        let mut batch_requests = Vec::with_capacity(max_batch_size);
         let mut batch_entries =
-            IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
+            IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());

         // Drain next_batch_size entries
-        self.entries
-            .drain(..next_batch_size)
-            .for_each(|(id, mut entry)| {
-                // Create a new span to link the batch back to this entry
-                let entry_batch_span =
-                    info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                // Add relationships
-                next_batch_span.follows_from(&entry_batch_span);
-                entry_batch_span.follows_from(&next_batch_span);
+        for (id, mut entry) in self.entries.drain(..max_batch_size) {
+            // Filter entries where the response receiver was dropped (== entries where the request
+            // was dropped by the client)
+            if entry.response_tx.is_disconnected() {
+                continue;
+            }
+
+            // Create a new span to link the batch back to this entry
+            let entry_batch_span = info_span!(parent: &entry.span, "infer");
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
+            entry_batch_span.follows_from(&next_batch_span);
@@ -181,12 +182,16 @@ impl State {
-                entry.batch_time = Some(Instant::now());
-                // Insert in batch_entries IntMap
-                batch_entries.insert(id, entry);
-            });
+            entry.batch_time = Some(Instant::now());
+            // Insert in batch_entries IntMap
+            batch_entries.insert(id, entry);
+        }
+
+        // Final batch size once we dropped entries
+        let size = batch_requests.len() as u32;
+        next_batch_span.record("batch_size", size);

         let batch = Batch {
             id: self.next_batch_id,
             requests: batch_requests,
-            size: next_batch_size as u32,
+            size,
         };
         // Increment batch id
         self.next_batch_id += 1;
@@ -219,9 +224,7 @@ mod tests {
     use tracing::info_span;

     fn default_entry() -> Entry {
-        let semaphore = Arc::new(Semaphore::new(1));
         let (response_tx, _) = flume::unbounded();
-        let permit = semaphore.try_acquire_owned().unwrap();

         Entry {
             request: ValidGenerateRequest {
@@ -248,7 +251,6 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         }
     }

router/src/server.rs

@@ -367,7 +367,8 @@ async fn generate_stream(
     let best_of = req.0.parameters.best_of.unwrap_or(1);
     if best_of == 1 {
         match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
-            Ok(mut response_stream) => {
+            // Keep permit as long as generate_stream lives
+            Ok((_permit, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
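
For context, here is a minimal standalone sketch (not part of this commit) of the two mechanisms the change relies on: a tokio OwnedSemaphorePermit that keeps a concurrency slot reserved for exactly as long as the caller holds it (which is why generate_stream now returns the permit alongside the stream), and flume's Sender::is_disconnected(), which the queue now uses to skip entries whose client dropped the response receiver. The dependency setup (tokio with its full feature set, plus flume) is an assumption made only to keep the sketch runnable.

// Illustrative sketch only; this is not code from the commit.
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // A permit bounds concurrency for as long as the binding that owns it is alive.
    let limit = Arc::new(Semaphore::new(1));
    let _permit = limit.clone().acquire_owned().await.unwrap();
    assert_eq!(limit.available_permits(), 0);

    // Dropping the receiving half of a flume channel is visible to the sender,
    // which is how the queue detects requests abandoned by the client.
    let (response_tx, response_rx) = flume::unbounded::<u32>();
    assert!(!response_tx.is_disconnected());
    drop(response_rx);
    assert!(response_tx.is_disconnected());
}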