mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
wip
This commit is contained in:
parent
4e63d9cb28
commit
9476170dda
@ -3,14 +3,16 @@ use crate::validation::{Validation, ValidationError};
|
||||
use crate::{Entry, Queue, Token};
|
||||
use crate::{GenerateRequest, PrefillToken};
|
||||
use flume::r#async::RecvStream;
|
||||
use flume::SendError;
|
||||
use futures::future::try_join_all;
|
||||
use futures::stream::StreamExt;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use flume::SendError;
|
||||
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{Notify, Semaphore, TryAcquireError};
|
||||
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
@ -72,9 +74,14 @@ impl Infer {
|
||||
pub(crate) async fn generate_stream(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
) -> Result<
|
||||
(
|
||||
OwnedSemaphorePermit,
|
||||
RecvStream<Result<InferStreamResponse, InferError>>,
|
||||
),
|
||||
InferError,
|
||||
> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
// This permit will live as long as Entry
|
||||
let permit = self
|
||||
.clone()
|
||||
.limit_concurrent_requests
|
||||
@ -103,7 +110,6 @@ impl Infer {
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
_permit: permit,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
@ -111,7 +117,7 @@ impl Infer {
|
||||
self.shared.batching_task.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(response_rx.into_stream())
|
||||
Ok((permit, response_rx.into_stream()))
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a InferResponse
|
||||
@ -120,8 +126,8 @@ impl Infer {
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<InferResponse, InferError> {
|
||||
// Create stream
|
||||
let mut stream = self.generate_stream(request).await?;
|
||||
// Create stream and keep semaphore permit as long as generate lives
|
||||
let (_permit, mut stream) = self.generate_stream(request).await?;
|
||||
|
||||
// Return values
|
||||
let mut result_prefill = Vec::new();
|
||||
@ -275,12 +281,10 @@ async fn batching_task(
|
||||
.next_batch(min_size, max_batch_size - batch_size as usize)
|
||||
.await
|
||||
{
|
||||
let new_batch_size = new_batch.size;
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span =
|
||||
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
@ -307,8 +311,7 @@ async fn batching_task(
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span =
|
||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
@ -340,17 +343,19 @@ async fn prefill(
|
||||
Ok((generations, next_batch)) => {
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
let next_batch = {
|
||||
let mut batch = next_batch.expect("next_batch is None. This is a bug.");
|
||||
|
||||
batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
|
||||
let size = batch.requests.len();
|
||||
if size == 0 {
|
||||
let _ = client.clear_cache(Some(batch.id)).await;
|
||||
return None;
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = match next_batch {
|
||||
None => None,
|
||||
Some(batch) => {
|
||||
let id = batch.id;
|
||||
let next_batch = filter_batch(batch, entries);
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
if next_batch.is_none() {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
next_batch
|
||||
}
|
||||
batch.size = size as u32;
|
||||
Some(batch)
|
||||
};
|
||||
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
@ -381,17 +386,19 @@ async fn decode(
|
||||
Ok((generations, next_batch)) => {
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
let next_batch = {
|
||||
let mut batch = next_batch.expect("next_batch is None. This is a bug.");
|
||||
|
||||
batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
|
||||
let size = batch.requests.len();
|
||||
if size == 0 {
|
||||
let _ = client.clear_cache(Some(batch.id)).await;
|
||||
return None;
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = match next_batch {
|
||||
None => None,
|
||||
Some(batch) => {
|
||||
let id = batch.id;
|
||||
let next_batch = filter_batch(batch, entries);
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
if next_batch.is_none() {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
next_batch
|
||||
}
|
||||
batch.size = size as u32;
|
||||
Some(batch)
|
||||
};
|
||||
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
@ -410,22 +417,16 @@ async fn decode(
|
||||
}
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
|
||||
batch.requests.retain(|r| entries.contains_key(&r.id));
|
||||
let size = batch.requests.len();
|
||||
if size == 0 {
|
||||
return None;
|
||||
}
|
||||
batch.size = size as u32;
|
||||
Some(batch)
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
@ -442,22 +443,27 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation back to infer task
|
||||
// If the receive an error from the Flume channel, we need to stop generating for this
|
||||
// request hence why we unwrap_or(true)
|
||||
let stopped = send_generation(generation, entry).unwrap_or(true);
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Send message
|
||||
entry.response_tx
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
@ -473,22 +479,39 @@ fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendEr
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx
|
||||
.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
generated_text,
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
generated_text,
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
} else {
|
||||
// Send message
|
||||
entry.response_tx
|
||||
.send(Ok(InferStreamResponse::Token(token)))
|
||||
?;
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Token(token)))?;
|
||||
}
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum InferStreamResponse {
|
||||
// Optional first message
|
||||
|
@ -4,7 +4,7 @@ use crate::validation::ValidGenerateRequest;
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::{oneshot, OwnedSemaphorePermit};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
|
||||
@ -23,8 +23,6 @@ pub(crate) struct Entry {
|
||||
pub queue_time: Instant,
|
||||
/// Instant when this entry was added to a batch
|
||||
pub batch_time: Option<Instant>,
|
||||
/// Permit
|
||||
pub _permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
/// Request Queue
|
||||
@ -147,46 +145,53 @@ impl State {
|
||||
}
|
||||
}
|
||||
|
||||
let next_batch_size = min(self.entries.len(), max_size);
|
||||
let max_batch_size = min(self.entries.len(), max_size);
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(next_batch_size);
|
||||
let mut batch_requests = Vec::with_capacity(max_batch_size);
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
|
||||
IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());
|
||||
|
||||
// Drain next_batch_size entries
|
||||
self.entries
|
||||
.drain(..next_batch_size)
|
||||
.for_each(|(id, mut entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span =
|
||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
for (id, mut entry) in self.entries.drain(..max_batch_size) {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_disconnected() {
|
||||
continue;
|
||||
}
|
||||
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
}
|
||||
|
||||
// Final batch size once we dropped entries
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
|
||||
let batch = Batch {
|
||||
id: self.next_batch_id,
|
||||
requests: batch_requests,
|
||||
size: next_batch_size as u32,
|
||||
size,
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
@ -219,9 +224,7 @@ mod tests {
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> Entry {
|
||||
let semaphore = Arc::new(Semaphore::new(1));
|
||||
let (response_tx, _) = flume::unbounded();
|
||||
let permit = semaphore.try_acquire_owned().unwrap();
|
||||
|
||||
Entry {
|
||||
request: ValidGenerateRequest {
|
||||
@ -248,7 +251,6 @@ mod tests {
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
_permit: permit,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,8 @@ async fn generate_stream(
|
||||
let best_of = req.0.parameters.best_of.unwrap_or(1);
|
||||
if best_of == 1 {
|
||||
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
Ok(mut response_stream) => {
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, mut response_stream)) => {
|
||||
// Server-Sent Event stream
|
||||
while let Some(response) = response_stream.next().await {
|
||||
match response {
|
||||
|
Loading…
Reference in New Issue
Block a user