text-generation-inference/router/src/batcher.rs

/// Batching and inference logic
use crate::{Db, Entry};
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;
/// Batcher
2022-10-08 10:30:12 +00:00
#[derive(Clone)]
pub struct Batcher {
/// Request database
2022-10-08 10:30:12 +00:00
db: Db,
/// Shared state
shared: Arc<Shared>,
}
/// Batcher shared state
struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify,
}
impl Batcher {
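/// Create a new `Batcher` and spawn the background batching task that drives inference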
pub(crate) fn new(
client: ShardedClient,
max_batch_size: usize,
max_waiting_tokens: usize,
) -> Self {
// Batcher shared state
let db = Db::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
max_batch_size,
max_waiting_tokens,
db.clone(),
shared.clone(),
));
Self { db, shared }
}
/// Add a new request to the database and return a future that will generate the text
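///
/// A minimal usage sketch (not a doctest): the handler shape below is an
/// illustrative assumption, and `input_length` would normally come from the
/// router's validation/tokenization step rather than being hard-coded.
///
/// ```ignore
/// // Hypothetical in-crate caller holding a cloned `Batcher`
/// async fn handle(batcher: Batcher, request: GenerateRequest) -> Result<String, InferError> {
///     let input_length = 12; // assumed: pre-computed token count of the request inputs
///     let response = batcher.infer(input_length, request).await?;
///     Ok(response.output)
/// }
/// ```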
pub(crate) async fn infer(
&self,
input_length: usize,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();
// Try to append the request to the database
self.db.append(Entry {
request,
response_tx,
input_length,
time: Instant::now(),
batch_time: None,
});
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_waiters();
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender
match response_rx.await.unwrap() {
Ok(output) => Ok(output),
Err(err) => Err(InferError::GenerationError(err.to_string())),
}
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
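///
/// Outline of one iteration (summary of the code below):
/// 1. Wait until the `Batcher` signals that new entries were appended to the `Db`
/// 2. Pull an initial batch from the `Db` and run a first `generate` (prefill) on it
/// 3. While the inference server keeps returning a cached batch, decode it with
///    `generate_with_cache`; whenever the running batch has shrunk to at most half
///    of `max_batch_size`, try to prefill and merge newly queued requests (the
///    minimum-size requirement is waived once `waiting_tokens` reaches
///    `max_waiting_tokens`)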
#[instrument(skip(client, db, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
max_waiting_tokens: usize,
db: Db,
shared: Arc<Shared>,
) {
// Batch size threshold: when the running batch is at or below this size, we try to add more requests to it
let limit_min_batch_size = (max_batch_size / 2) as u32;
// Infinite loop
loop {
// Wait for a notification from the Batcher struct
shared.batching_task.notified().await;
// Get the next batch from the DB
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
let mut waiting_tokens = 0;
if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
waiting_tokens += 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
let mut batches = vec![batch];
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
let min_size = match waiting_tokens {
// If we haven't onboarded any new requests for at least max_waiting_tokens decode
// steps, we try to add a new batch even though its size might be small
_ if waiting_tokens >= max_waiting_tokens => None,
// Minimum size criteria
_ => Some(limit_min_batch_size as usize),
};
// Try to get a new batch
if let Some((new_request_ids, new_batch)) =
db.next_batch(min_size, max_batch_size)
{
// Reset waiting counter
waiting_tokens = 0;
// Generate one token for this new batch so that its past attention keys/values are cached
let new_cached_batch =
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
batches.push(new_cached_batch);
}
}
}
cached_batch =
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
waiting_tokens += 1;
}
}
}
}
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
request_ids: Vec<u64>,
db: &Db,
) -> Option<Batch> {
match future.await {
Ok((generated_texts, next_batch)) => {
send_generated(generated_texts, db);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_error(err, request_ids, db);
None
}
}
}
/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| {
// We can `expect` here as the request id should always be in the DB
let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(error.clone())).unwrap_or(());
});
}
/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
finished.into_iter().for_each(|output| {
// We can `expect` here as the request id should always be in the DB
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
let response = InferResponse {
output: output.output,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid: batch_time is set by the Db when the entry is added to a batch
end: Instant::now(),
};
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Ok(response)).unwrap_or(());
});
}
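/// Generated text plus timing information for a single request.
///
/// A minimal sketch (not a doctest) of how the timestamps could be turned into
/// latency figures; the helper below is illustrative, not part of this crate:
///
/// ```ignore
/// use std::time::Duration;
///
/// fn timings(response: &InferResponse) -> (Duration, Duration) {
///     let queue_time = response.start - response.queued; // time spent waiting in the Db
///     let inference_time = response.end - response.start; // time spent generating tokens
///     (queue_time, inference_time)
/// }
/// ```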
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) output: String,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) end: Instant,
}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
}
/// Convert to Axum supported format
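///
/// Sketch of how a handler can rely on this conversion through `?`; the handler
/// signature below is an illustrative assumption, not this crate's actual route:
///
/// ```ignore
/// async fn generate(
///     batcher: Batcher,
///     input_length: usize,
///     request: GenerateRequest,
/// ) -> Result<Json<String>, (StatusCode, Json<ErrorResponse>)> {
///     // `?` maps InferError into (StatusCode, Json<ErrorResponse>) via this From impl
///     let response = batcher.infer(input_length, request).await?;
///     Ok(Json(response.output))
/// }
/// ```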
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: err.to_string(),
}),
),
}
}
}