commit 9476170dda (parent 4e63d9cb28)

    wip
router/src/infer.rs

@@ -3,14 +3,16 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
+use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use flume::SendError;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
+use text_generation_client::{
+    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+};
 use thiserror::Error;
-use tokio::sync::{Notify, Semaphore, TryAcquireError};
+use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};

@@ -72,9 +74,14 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            RecvStream<Result<InferStreamResponse, InferError>>,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
-        // This permit will live as long as Entry
         let permit = self
             .clone()
             .limit_concurrent_requests
@@ -103,7 +110,6 @@ impl Infer {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -111,7 +117,7 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok(response_rx.into_stream())
+        Ok((permit, response_rx.into_stream()))
     }

     /// Add a new request to the queue and return a InferResponse
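Note on the permit handoff: `generate_stream` now returns the `OwnedSemaphorePermit` next to the stream instead of parking it inside `Entry`, so the concurrency slot stays taken for exactly as long as the caller keeps the stream. A minimal, self-contained sketch of that pattern (hypothetical `guarded_stream` helper, not code from this commit):

use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};

// Hypothetical helper: hand the owned permit back to the caller together with
// the value it guards, so the slot is only released once both are dropped.
async fn guarded_stream(
    semaphore: Arc<Semaphore>,
) -> Result<(OwnedSemaphorePermit, Vec<u32>), TryAcquireError> {
    let permit = semaphore.try_acquire_owned()?;
    Ok((permit, vec![1, 2, 3]))
}

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    let (permit, stream) = guarded_stream(semaphore.clone()).await.unwrap();
    assert_eq!(semaphore.available_permits(), 0); // still held by the caller
    drop((permit, stream));
    assert_eq!(semaphore.available_permits(), 1); // released on drop
}

Releasing the slot only when the returned tuple is dropped is what allows the `_permit` field to be removed from `Entry` in the hunks below.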
@@ -120,8 +126,8 @@ impl Infer {
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
-        // Create stream
-        let mut stream = self.generate_stream(request).await?;
+        // Create stream and keep semaphore permit as long as generate lives
+        let (_permit, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -275,12 +281,10 @@ async fn batching_task(
                 .next_batch(min_size, max_batch_size - batch_size as usize)
                 .await
             {
-                let new_batch_size = new_batch.size;
                 entries.iter_mut().for_each(|(_, entry)| {
                     // Create a new span to add the info that this entry is waiting
                     // because a new batch is being computed
-                    let entry_waiting_span =
-                        info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                    let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                     // Add relationships
                     span.follows_from(&entry_waiting_span);
                     entry_waiting_span.follows_from(&span);
@@ -307,8 +311,7 @@ async fn batching_task(
                 info_span!(parent: None, "batch", batch_size = next_batch_size);
             entries.iter_mut().for_each(|(_, entry)| {
                 // Create a new span to link the batch back to this entry
-                let entry_batch_span =
-                    info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
+                let entry_batch_span = info_span!(parent: &entry.span, "infer");
                 // Add relationships
                 next_batch_span.follows_from(&entry_batch_span);
                 entry_batch_span.follows_from(&next_batch_span);
@@ -340,17 +343,19 @@ async fn prefill(
         Ok((generations, next_batch)) => {
             filter_send_generations(generations, entries);

-            let next_batch = {
-                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
-                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
-                let size = batch.requests.len();
-                if size == 0 {
-                    let _ = client.clear_cache(Some(batch.id)).await;
-                    return None;
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
                 }
-                batch.size = size as u32;
-                Some(batch)
             };

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
@@ -381,17 +386,19 @@ async fn decode(
         Ok((generations, next_batch)) => {
             filter_send_generations(generations, entries);

-            let next_batch = {
-                let mut batch = next_batch.expect("next_batch is None. This is a bug.");
-                batch.requests = batch.requests.into_iter().filter(|r| { entries.contains_key(&r.id) }).collect();
-                let size = batch.requests.len();
-                if size == 0 {
-                    let _ = client.clear_cache(Some(batch.id)).await;
-                    return None;
+            // Filter next batch and remove requests that were stopped
+            let next_batch = match next_batch {
+                None => None,
+                Some(batch) => {
+                    let id = batch.id;
+                    let next_batch = filter_batch(batch, entries);
+                    // Next batch is now empty
+                    // Clear it from the Python shards cache
+                    if next_batch.is_none() {
+                        let _ = client.clear_cache(Some(id)).await;
+                    }
+                    next_batch
                 }
-                batch.size = size as u32;
-                Some(batch)
             };

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
@@ -410,22 +417,16 @@ async fn decode(
         }
     }
 }

-/// Send errors to Infer for all `entries`
+/// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
-    entries.drain().for_each(|(_, entry)| {
-        // Create and enter a span to link this function back to the entry
-        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
-        let err = InferError::GenerationError(error.to_string());
-        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
-        tracing::error!("{err}");
-
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry
-            .response_tx
-            .send(Err(err))
-            .unwrap_or(());
-    });
+fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+    batch.requests.retain(|r| entries.contains_key(&r.id));
+    let size = batch.requests.len();
+    if size == 0 {
+        return None;
+    }
+    batch.size = size as u32;
+    Some(batch)
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
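The new `filter_batch` helper is essentially `Vec::retain` against the map of live entries, returning `None` when nothing survives so the caller knows to clear the shards' cache. A tiny sketch of that core behaviour (hypothetical `Req` type and a plain `HashMap` standing in for the router's `IntMap`):

use std::collections::HashMap;

// Hypothetical request type standing in for text_generation_client::Request.
struct Req {
    id: u64,
}

fn main() {
    // Entries still alive in the router (ids 1 and 3).
    let entries: HashMap<u64, ()> = HashMap::from([(1, ()), (3, ())]);

    // Requests currently in the batch; id 2 has already finished or was stopped.
    let mut requests = vec![Req { id: 1 }, Req { id: 2 }, Req { id: 3 }];

    // Same core step as filter_batch: keep only requests with a live entry.
    requests.retain(|r| entries.contains_key(&r.id));

    let ids: Vec<u64> = requests.iter().map(|r| r.id).collect();
    assert_eq!(ids, vec![1, 3]); // request 2 is dropped from the next batch
}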
@@ -442,22 +443,27 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {

         // Create and enter a span to link this function back to the entry
         let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
-        // Send generation back to infer task
-        // If the receive an error from the Flume channel, we need to stop generating for this
-        // request hence why we unwrap_or(true)
-        let stopped = send_generation(generation, entry).unwrap_or(true);
+        // Send generation responses back to the infer task
+        // If the receive an error from the Flume channel, it means that the client dropped the
+        // request and we need to stop generating hence why we unwrap_or(true)
+        let stopped = send_responses(generation, entry).unwrap_or(true);
         if stopped {
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
         }
     });
 }

-fn send_generation(generation: Generation, entry: &Entry) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+/// Send responses through the `entry` response channel
+fn send_responses(
+    generation: Generation,
+    entry: &Entry,
+) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
     let mut stopped = false;

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx
+        entry
+            .response_tx
             .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }

|
|||||||
// Generation has ended
|
// Generation has ended
|
||||||
stopped = true;
|
stopped = true;
|
||||||
// Send message
|
// Send message
|
||||||
entry.response_tx
|
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||||
.send(Ok(InferStreamResponse::End {
|
token,
|
||||||
token,
|
generated_text,
|
||||||
generated_text,
|
queued: entry.queue_time,
|
||||||
queued: entry.queue_time,
|
start: entry.batch_time.unwrap(),
|
||||||
start: entry.batch_time.unwrap(),
|
}))?;
|
||||||
}))?;
|
|
||||||
} else {
|
} else {
|
||||||
// Send message
|
// Send message
|
||||||
entry.response_tx
|
entry
|
||||||
.send(Ok(InferStreamResponse::Token(token)))
|
.response_tx
|
||||||
?;
|
.send(Ok(InferStreamResponse::Token(token)))?;
|
||||||
}
|
}
|
||||||
Ok(stopped)
|
Ok(stopped)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Send errors to Infer for all `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
entries.drain().for_each(|(_, entry)| {
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
|
let err = InferError::GenerationError(error.to_string());
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Err(err))
|
||||||
|
.unwrap_or(());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) enum InferStreamResponse {
|
pub(crate) enum InferStreamResponse {
|
||||||
// Optional first message
|
// Optional first message
|
||||||
|
router/src/queue.rs

@@ -4,7 +4,7 @@ use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use text_generation_client::{Batch, Request};
-use tokio::sync::{oneshot, OwnedSemaphorePermit};
+use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

@@ -23,8 +23,6 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
-    /// Permit
-    pub _permit: OwnedSemaphorePermit,
 }

 /// Request Queue
@@ -147,46 +145,53 @@ impl State {
             }
         }

-        let next_batch_size = min(self.entries.len(), max_size);
+        let max_batch_size = min(self.entries.len(), max_size);

         // Create span for this batch to add context to inference calls
-        let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
+        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(&Span::current());

-        let mut batch_requests = Vec::with_capacity(next_batch_size);
+        let mut batch_requests = Vec::with_capacity(max_batch_size);
         let mut batch_entries =
-            IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
+            IntMap::with_capacity_and_hasher(max_batch_size, BuildNoHashHasher::default());

         // Drain next_batch_size entries
-        self.entries
-            .drain(..next_batch_size)
-            .for_each(|(id, mut entry)| {
-                // Create a new span to link the batch back to this entry
-                let entry_batch_span =
-                    info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                // Add relationships
-                next_batch_span.follows_from(&entry_batch_span);
-                entry_batch_span.follows_from(&next_batch_span);
-                // Update entry
-                entry.temp_span = Some(entry_batch_span);
-
-                batch_requests.push(Request {
-                    id,
-                    inputs: entry.request.inputs.clone(),
-                    truncate: entry.request.truncate,
-                    parameters: Some(entry.request.parameters.clone()),
-                    stopping_parameters: Some(entry.request.stopping_parameters.clone()),
-                });
-                // Set batch_time
-                entry.batch_time = Some(Instant::now());
-                // Insert in batch_entries IntMap
-                batch_entries.insert(id, entry);
-            });
+        for (id, mut entry) in self.entries.drain(..max_batch_size) {
+            // Filter entries where the response receiver was dropped (== entries where the request
+            // was dropped by the client)
+            if entry.response_tx.is_disconnected() {
+                continue;
+            }
+
+            // Create a new span to link the batch back to this entry
+            let entry_batch_span = info_span!(parent: &entry.span, "infer");
+            // Add relationships
+            next_batch_span.follows_from(&entry_batch_span);
+            entry_batch_span.follows_from(&next_batch_span);
+            // Update entry
+            entry.temp_span = Some(entry_batch_span);
+
+            batch_requests.push(Request {
+                id,
+                inputs: entry.request.inputs.clone(),
+                truncate: entry.request.truncate,
+                parameters: Some(entry.request.parameters.clone()),
+                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+            });
+            // Set batch_time
+            entry.batch_time = Some(Instant::now());
+            // Insert in batch_entries IntMap
+            batch_entries.insert(id, entry);
+        }
+
+        // Final batch size once we dropped entries
+        let size = batch_requests.len() as u32;
+        next_batch_span.record("batch_size", size);

         let batch = Batch {
             id: self.next_batch_id,
             requests: batch_requests,
-            size: next_batch_size as u32,
+            size,
         };
         // Increment batch id
         self.next_batch_id += 1;
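`State::next_batch` now skips entries whose client already gave up, by checking the flume sender before building a request. A minimal sketch of the disconnect check (assumes the `flume` crate as a dependency; the channel here stands in for `Entry::response_tx`):

// Minimal sketch: flume's Sender::is_disconnected() becomes true once every
// Receiver has been dropped, which is how the queue detects that the client
// abandoned its request before the entry made it into a batch.
fn main() {
    let (response_tx, response_rx) = flume::unbounded::<u32>();
    assert!(!response_tx.is_disconnected()); // client is still listening

    drop(response_rx); // the client went away (request cancelled)
    assert!(response_tx.is_disconnected()); // next_batch can skip this entry
}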
@@ -219,9 +224,7 @@ mod tests {
     use tracing::info_span;

     fn default_entry() -> Entry {
-        let semaphore = Arc::new(Semaphore::new(1));
         let (response_tx, _) = flume::unbounded();
-        let permit = semaphore.try_acquire_owned().unwrap();

         Entry {
             request: ValidGenerateRequest {
@@ -248,7 +251,6 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            _permit: permit,
         }
     }

router/src/server.rs

@@ -367,7 +367,8 @@ async fn generate_stream(
     let best_of = req.0.parameters.best_of.unwrap_or(1);
     if best_of == 1 {
         match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
-            Ok(mut response_stream) => {
+            // Keep permit as long as generate_stream lives
+            Ok((_permit, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
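In the `Ok((_permit, mut response_stream))` arm above, the permit is bound to a named pattern on purpose: a binding that starts with an underscore still owns the value until the end of the arm, whereas a bare `_` would drop it immediately. A small sketch of that distinction (hypothetical `Permit` type with a `Drop` impl for illustration):

// Hypothetical permit type: Drop shows when the concurrency slot is released.
struct Permit;

impl Drop for Permit {
    fn drop(&mut self) {
        println!("permit released");
    }
}

fn main() {
    let result: Result<(Permit, Vec<u32>), ()> = Ok((Permit, vec![1, 2, 3]));
    match result {
        // `_permit` keeps the permit alive for the whole arm; `_` would not.
        Ok((_permit, stream)) => {
            for token in stream {
                println!("token {token}");
            }
            // "permit released" is printed here, after the streaming loop.
        }
        Err(()) => eprintln!("error"),
    }
}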