OlivierDehaene 2023-04-16 18:52:47 +02:00
parent 4e63d9cb28
commit 9476170dda
3 changed files with 125 additions and 99 deletions

View File

@@ -3,14 +3,16 @@ use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use flume::r#async::RecvStream;
use flume::SendError;
use futures::future::try_join_all;
use futures::stream::StreamExt;
use nohash_hasher::IntMap;
use std::sync::Arc;
use flume::SendError;
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{Notify, Semaphore, TryAcquireError};
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
@@ -72,9 +74,14 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
) -> Result<
(
OwnedSemaphorePermit,
RecvStream<Result<InferStreamResponse, InferError>>,
),
InferError,
> {
// Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry
let permit = self
.clone()
.limit_concurrent_requests
@@ -103,7 +110,6 @@ impl Infer {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
});
// Notify the background task that we have a new entry in the queue that needs
@@ -111,7 +117,7 @@
self.shared.batching_task.notify_one();
// Return stream
Ok(response_rx.into_stream())
Ok((permit, response_rx.into_stream()))
}
/// Add a new request to the queue and return an InferResponse
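The signature change above is the heart of this commit: instead of storing the OwnedSemaphorePermit inside Entry, generate_stream now returns it next to the response stream, so the concurrency slot is released exactly when the caller drops it, typically once the stream has been consumed or the client has disconnected. A minimal sketch of that pattern, assuming only tokio's semaphore API (the names below are illustrative, not taken from the repository):

use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};

/// Illustrative helper (not part of the router): acquire a concurrency slot
/// that stays taken until the returned permit is dropped.
fn acquire_slot(limit: &Arc<Semaphore>) -> Result<OwnedSemaphorePermit, TryAcquireError> {
    // `try_acquire_owned` consumes an Arc<Semaphore>, hence the clone;
    // the returned permit keeps the semaphore alive and releases the slot on drop.
    limit.clone().try_acquire_owned()
}

#[tokio::main]
async fn main() {
    let limit = Arc::new(Semaphore::new(2));
    let _permit = acquire_slot(&limit).expect("model is overloaded");
    // ... stream the response while `_permit` is held ...
    // dropping `_permit` at the end of the scope frees the slot
}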
@@ -120,8 +126,8 @@
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// Create stream
let mut stream = self.generate_stream(request).await?;
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
@@ -275,12 +281,10 @@ async fn batching_task(
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
let new_batch_size = new_batch.size;
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
@@ -307,8 +311,7 @@ async fn batching_task(
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
@@ -340,17 +343,19 @@ async fn prefill(
Ok((generations, next_batch)) => {
filter_send_generations(generations, entries);
let next_batch = {
let mut batch = next_batch.expect("next_batch is None. This is a bug.");
batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
let size = batch.requests.len();
if size == 0 {
let _ = client.clear_cache(Some(batch.id)).await;
return None;
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
None => None,
Some(batch) => {
let id = batch.id;
let next_batch = filter_batch(batch, entries);
// Next batch is now empty
// Clear it from the Python shards cache
if next_batch.is_none() {
let _ = client.clear_cache(Some(id)).await;
}
next_batch
}
batch.size = size as u32;
Some(batch)
};
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
@@ -381,17 +386,19 @@ async fn decode(
Ok((generations, next_batch)) => {
filter_send_generations(generations, entries);
let next_batch = {
let mut batch = next_batch.expect("next_batch is None. This is a bug.");
batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
let size = batch.requests.len();
if size == 0 {
let _ = client.clear_cache(Some(batch.id)).await;
return None;
// Filter next batch and remove requests that were stopped
let next_batch = match next_batch {
None => None,
Some(batch) => {
let id = batch.id;
let next_batch = filter_batch(batch, entries);
// Next batch is now empty
// Clear it from the Python shards cache
if next_batch.is_none() {
let _ = client.clear_cache(Some(id)).await;
}
next_batch
}
batch.size = size as u32;
Some(batch)
};
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
@@ -410,22 +417,16 @@ async fn decode(
}
}
/// Send errors to Infer for all `entries`
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
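// Keep only the requests that still have a matching entry in `entries`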
batch.requests.retain(|r| entries.contains_key(&r.id));
let size = batch.requests.len();
if size == 0 {
return None;
}
batch.size = size as u32;
Some(batch)
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
@@ -442,22 +443,27 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation back to infer task
// If we receive an error from the Flume channel, we need to stop generating for this
// request, hence the unwrap_or(true)
let stopped = send_generation(generation, entry).unwrap_or(true);
// Send generation responses back to the infer task
// If we receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating, hence the unwrap_or(true)
let stopped = send_responses(generation, entry).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
entry.response_tx
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
@@ -473,22 +479,39 @@ fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendEr
// Generation has ended
stopped = true;
// Send message
entry.response_tx
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
} else {
// Send message
entry.response_tx
.send(Ok(InferStreamResponse::Token(token)))
?;
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))?;
}
Ok(stopped)
}
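send_responses reports whether this request reached a stopping point, and the caller above converts a SendError into an implicit stop with unwrap_or(true): flume's send only fails once every receiver has been dropped, i.e. the client is gone. A self-contained sketch of that signal (toy message type and helper, not the TGI stream types):

// Illustrative helper, not a TGI function: forward one message and report
// whether it was the last one for this request.
fn deliver(
    tx: &flume::Sender<String>,
    msg: String,
    is_last: bool,
) -> Result<bool, flume::SendError<String>> {
    // `send` fails only when all receivers have been dropped
    tx.send(msg)?;
    Ok(is_last)
}

fn main() {
    let (tx, rx) = flume::unbounded::<String>();
    drop(rx); // the client cancelled the request
    // Treat the send error as "stopped" so the entry is removed from the batch
    let stopped = deliver(&tx, "token".to_string(), false).unwrap_or(true);
    assert!(stopped);
}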
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message

View File

@@ -4,7 +4,7 @@ use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use text_generation_client::{Batch, Request};
use tokio::sync::{oneshot, OwnedSemaphorePermit};
use tokio::sync::oneshot;
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@@ -23,8 +23,6 @@ pub(crate) struct Entry {
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Permit
pub _permit: OwnedSemaphorePermit,
}
/// Request Queue
@@ -147,46 +145,53 @@ impl State {
}
}
let next_batch_size = min(self.entries.len(), max_size);
let max_batch_size = min(self.entries.len(), max_size);
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(next_batch_size);
let mut batch_requests = Vec::with_capacity(max_batch_size);
let mut batch_entries =
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
// Drain next_batch_size entries
self.entries
.drain(..next_batch_size)
.for_each(|(id, mut entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
for (id, mut entry) in self.entries.drain(..max_batch_size) {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_disconnected() {
continue;
}
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
// Insert in batch_entries IntMap
batch_entries.insert(id, entry);
}
// Final batch size once we dropped entries
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
let batch = Batch {
id: self.next_batch_id,
requests: batch_requests,
size: next_batch_size as u32,
size,
};
// Increment batch id
self.next_batch_id += 1;
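The is_disconnected check above is what lets the queue drop cancelled requests before they ever reach a batch: a flume sender can observe that every receiver has been dropped. A small illustration, assuming only flume's public API:

fn main() {
    let (response_tx, response_rx) = flume::unbounded::<u32>();
    assert!(!response_tx.is_disconnected());

    // Dropping the receiver is what happens when the HTTP client disconnects
    drop(response_rx);

    // The queue can now skip this entry instead of scheduling dead work
    assert!(response_tx.is_disconnected());
}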
@@ -219,9 +224,7 @@ mod tests {
use tracing::info_span;
fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
let (response_tx, _) = flume::unbounded();
let permit = semaphore.try_acquire_owned().unwrap();
Entry {
request: ValidGenerateRequest {
@@ -248,7 +251,6 @@ mod tests {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
}
}
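One more detail from the next_batch hunk above: because entries can now be skipped, the final batch size is only known after the drain loop, so the span is created with batch_size = tracing::field::Empty and the value is recorded afterwards. The same deferred-record pattern in isolation (a sketch, not repository code; without a subscriber installed the record call is a no-op, but the API usage is identical):

use tracing::{field, info_span};

fn main() {
    // Declare the field up front so it can be filled in once the value is known
    let span = info_span!("batch", batch_size = field::Empty);

    let size: u32 = 3; // e.g. computed after filtering disconnected entries
    span.record("batch_size", size);
}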

View File

@@ -367,7 +367,8 @@ async fn generate_stream(
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
match response {