use crate::server::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
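
/// Maximum number of entries allowed to wait in the queue; past this,
/// `infer` fails fast with `InferError::Overloaded`.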
const MAX_LENGTH: usize = 128;
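
/// Errors returned by `Batcher::infer`.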
#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded,
}
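
/// Map `InferError` variants to HTTP status codes so axum handlers can
/// return them directly.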
impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
        }
    }
}
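
/// Queues client requests and hands them to the background batching task.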
#[derive(Clone)]
pub struct Batcher {
    db: Db,
    shared: Arc<Shared>,
}
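
/// State shared between `Batcher` handles and the background batching task.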
struct Shared {
    batching_task: Notify,
}
impl Batcher {
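    /// Create a new `Batcher` and spawn its background batching task.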
    pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });
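
        // Spawn the batching background task that talks to the inference shards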
        tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));

        Self { db, shared }
    }
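
    /// Queue a request and await the generated text.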
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
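        // Bound the queue: refuse new requests when too many are already waiting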
        if self.db.len() > MAX_LENGTH {
            return Err(InferError::Overloaded);
        }
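
        // Register the request in the database along with a one-shot channel on
        // which the batching task will send the result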
        let (request_tx, request_rx) = oneshot::channel();
        self.db.append(Entry {
            request,
            response_tx: request_tx,
            input_length,
        });
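
        // Notify the batching task that a new entry is available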
        self.shared.batching_task.notify_waiters();
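
        // Await the result; the batching task is expected to always send a
        // response before dropping the sender, hence the `unwrap`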
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}
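
/// Background task that pulls batches of queued requests from the `Db`,
/// sends them to the sharded inference client, and dispatches the results.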
async fn batching_task(
    max_batch_size: usize,
    client: ShardedClient,
    db: Db,
    shared: Arc<Shared>,
) {
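    // Batches at or below half the maximum size are candidates for being
    // merged with newly queued requests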
    let limit_min_batch_size = (max_batch_size / 2) as u32;
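
    // Wait to be woken by `infer`, then drain the queue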
    loop {
        shared.batching_task.notified().await;
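
        // Batches above the threshold are run until every request finishes;
        // smaller ones return after one generation step so the loop below can
        // try to grow them with new requests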
        if let Some(batch) = db.next_batch(max_batch_size) {
            let request_ids = batch.requests.iter().map(|req| req.id).collect();
            let mut cached_batch = match batch.size {
                size if size > limit_min_batch_size => {
                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                }
                _ => wrap_future(client.generate(batch), request_ids, &db).await,
            };
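
            // Keep decoding while the shards still hold a cached batch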
            while let Some(batch) = cached_batch {
                let mut current_batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];
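
                // If the running batch has shrunk below the threshold, try to
                // concatenate a fresh batch of queued requests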
                if current_batch_size <= limit_min_batch_size {
                    if let Some(new_batch) =
                        db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size)
                    {
                        let new_batch_request_ids =
                            new_batch.requests.iter().map(|req| req.id).collect();
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                .await;
                        if let Some(new_cached_batch) = new_cached_batch {
                            current_batch_size += new_cached_batch.size;
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }
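
                // Continue generation on the (possibly concatenated) batches,
                // reusing the past state cached on the shards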
                cached_batch = match current_batch_size {
                    size if size > limit_min_batch_size => {
                        wrap_future(
                            client.generate_until_finished_with_cache(batches),
                            request_ids,
                            &db,
                        )
                        .await
                    }
                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
                };
            }
        }
    }
}
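
/// Await a generation future, send finished texts back to their callers, and
/// return the next cached batch, if any. On failure, error out every request
/// that was part of the batch.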
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}
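
/// Fail every request of a batch by forwarding the client error to each caller.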
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}
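
/// Deliver each finished generation to the caller that requested it.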
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
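        // `request` is assumed to always be set on finished generations, hence
        // the unwrap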
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(output.output)).unwrap_or(());
    });
}