feat(router): Remove second lock from batcher hot path

OlivierDehaene 2023-01-20 14:06:33 +01:00
parent ce960be0a5
commit 67ee1907fc
3 changed files with 65 additions and 67 deletions

View File

@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 /// Batching and inference logic
 use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
@@ -8,6 +9,7 @@ use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
+use tokio::sync::{Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::instrument;
@@ -24,6 +26,8 @@ pub struct Batcher {
 struct Shared {
     /// Batching background Tokio task notifier
     batching_task: Notify,
+    /// Inference request limit
+    limit_concurrent_requests: Semaphore,
 }
 
 impl Batcher {
@@ -31,11 +35,13 @@ impl Batcher {
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
+        max_concurrent_requests: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
+            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
         });
 
         // Spawn batching background task that contains all the inference logic
@@ -56,6 +62,9 @@ impl Batcher {
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
+
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
@@ -104,8 +113,8 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
+            let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
             let mut waiting_tokens = 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
@@ -113,7 +122,6 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
-                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];
 
                 // If the current batch is too small, we try to add more requests to it
@@ -127,24 +135,24 @@ async fn batching_task(
                     };
 
                     // Try to get a new batch
-                    if let Some((new_request_ids, new_batch)) =
+                    if let Some((mut new_entries, new_batch)) =
                         db.next_batch(min_size, max_batch_size - batch_size as usize)
                     {
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                            wrap_future(client.generate(new_batch), &mut new_entries).await;
                         // Reset waiting counter
                         waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+                            entries.extend(new_entries);
                             batches.push(new_cached_batch);
                         }
                     }
                 }
 
                 cached_batch =
-                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                    wrap_future(client.generate_with_cache(batches), &mut entries).await;
                 waiting_tokens += 1;
             }
         }
@@ -154,39 +162,36 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    request_ids: Vec<u64>,
-    db: &Db,
+    entries: &mut HashMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+            send_generated(generated_texts, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            send_error(err, request_ids, db);
+            send_error(err, entries);
             None
         }
     }
 }
 
-/// Send errors to the Batcher for all `request_ids`
-fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
-    request_ids.into_iter().for_each(|id| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+/// Send errors to the Batcher for all `entries`
+fn send_error(error: ClientError, entries: &mut HashMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut HashMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db
+        // We can `expect` here as the request id should always be in the entries
+        let entry = entries
             .remove(&output.request.unwrap().id)
-            .expect("ID not found in db. This is a bug.");
+            .expect("ID not found in entries. This is a bug.");
 
         let response = InferResponse {
             output_text: output.output_text,
@@ -221,18 +226,30 @@ pub(crate) struct InferResponse {
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
     GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded,
+}
+
+/// Convert semaphore error
+impl From<TryAcquireError> for InferError {
+    fn from(_: TryAcquireError) -> Self {
+        InferError::Overloaded
+    }
 }
 
 /// Convert to Axum supported format
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
-        match err {
-            InferError::GenerationError(_) => (
-                StatusCode::FAILED_DEPENDENCY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                }),
-            ),
-        }
+        let status_code = match err {
+            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
+            InferError::Overloaded => StatusCode::TOO_MANY_REQUESTS,
+        };
+
+        (
+            status_code,
+            Json(ErrorResponse {
+                error: err.to_string(),
+            }),
+        )
     }
 }
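
Context for the file above: the per-request concurrency limit moves out of the HTTP handlers and into `Batcher::infer`, expressed as a `tokio::sync::Semaphore` plus a `From<TryAcquireError>` conversion. The following is a minimal, standalone sketch of that pattern only; the `Limiter` type and its `infer` body are hypothetical stand-ins, not code from this repository, and it assumes tokio with the `sync`, `macros`, and `rt-multi-thread` features.

use tokio::sync::{Semaphore, TryAcquireError};

#[derive(Debug)]
enum InferError {
    Overloaded,
}

impl From<TryAcquireError> for InferError {
    fn from(_: TryAcquireError) -> Self {
        InferError::Overloaded
    }
}

struct Limiter {
    // Hypothetical stand-in for the `Shared` struct in the diff above
    limit_concurrent_requests: Semaphore,
}

impl Limiter {
    fn new(max_concurrent_requests: usize) -> Self {
        Self {
            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
        }
    }

    async fn infer(&self) -> Result<&'static str, InferError> {
        // `try_acquire` never waits: if no permit is free the call fails immediately,
        // and `?` converts TryAcquireError into InferError::Overloaded.
        let _permit = self.limit_concurrent_requests.try_acquire()?;
        // ... validation, queueing and generation would happen here ...
        Ok("generated text")
        // `_permit` is dropped when the call returns, freeing the slot.
    }
}

#[tokio::main]
async fn main() {
    let limiter = Limiter::new(1);
    println!("{:?}", limiter.infer().await); // Ok("generated text")
}

In the real code the `Overloaded` variant is then mapped to HTTP 429 by the `From<InferError>` impl shown at the end of the diff above.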

View File

@@ -2,7 +2,7 @@ use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use parking_lot::Mutex;
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -112,18 +112,12 @@ impl Db {
         state.entries.insert(id, entry);
     }
 
-    /// Remove an entry from the database if it exists
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.lock();
-        state.entries.remove(id)
-    }
-
     // Get the next batch
     pub(crate) fn next_batch(
         &self,
         min_size: Option<usize>,
         max_size: usize,
-    ) -> Option<(Vec<u64>, Batch)> {
+    ) -> Option<(HashMap<u64, Entry>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
@@ -135,13 +129,19 @@ impl Db {
                     return None;
                 }
             }
 
-            ids.iter().for_each(|id| {
-                // Set batch_time for each request
-                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
-            });
-
             // Batch size
             let size = requests.len();
 
+            let mut entries = HashMap::with_capacity(size);
+            ids.iter().for_each(|id| {
+                // Remove entry from db
+                let mut entry = state.entries.remove(id).unwrap();
+                // Set batch_time
+                entry.batch_time = Some(Instant::now());
+                // Insert in entries hashmap
+                entries.insert(*id, entry);
+            });
+
             let batch = Batch {
                 id: state.next_batch_id,
                 requests,
@@ -152,7 +152,7 @@ impl Db {
             // Increment batch id
             state.next_batch_id += 1;
 
-            return Some((ids, batch));
+            return Some((entries, batch));
         }
 
         None
     }
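
This file is where the second lock actually disappears: `next_batch` now moves each `Entry` out of the shared map while the single `parking_lot` lock is held, so the batching task can later answer or fail those requests without calling back into the `Db` (the old `Db::remove`, which re-took the lock once per request, is deleted). Below is a simplified sketch of that ownership transfer with made-up types rather than the repository's `Db`/`Entry`, assuming only the `parking_lot` and `tokio` crates.

use parking_lot::Mutex;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::oneshot;

struct Entry {
    response_tx: oneshot::Sender<Result<String, String>>,
}

struct Db {
    entries: Arc<Mutex<BTreeMap<u64, Entry>>>,
}

impl Db {
    // One lock acquisition: the selected entries are moved out and owned by the caller.
    fn next_batch(&self, max_size: usize) -> HashMap<u64, Entry> {
        let mut state = self.entries.lock();
        let ids: Vec<u64> = state.keys().take(max_size).copied().collect();
        let mut batch = HashMap::with_capacity(ids.len());
        for id in ids {
            batch.insert(id, state.remove(&id).unwrap());
        }
        batch
    }
}

// Completing (or failing) the batch never touches `Db` again, hence no second lock.
fn send_results(mut batch: HashMap<u64, Entry>) {
    for (_, entry) in batch.drain() {
        // unwrap_or(()): we don't care if the receiver hung up.
        entry.response_tx.send(Ok("done".to_string())).unwrap_or(());
    }
}

fn main() {
    let db = Db {
        entries: Arc::new(Mutex::new(BTreeMap::new())),
    };
    let (tx, mut rx) = oneshot::channel();
    db.entries.lock().insert(0, Entry { response_tx: tx });

    send_results(db.next_batch(8));
    assert_eq!(rx.try_recv().unwrap(), Ok("done".to_string()));
}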

View File

@@ -7,11 +7,9 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
@@ -20,7 +18,6 @@ use tracing::instrument;
 struct ServerState {
     validation: Validation,
     batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Health check method
@@ -30,16 +27,6 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     // be a bit too slow for a health check.
     // What we should do instead if check if the gRPC channels are still healthy.
 
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
-
     // Send a small inference request
     state
         .batcher
@@ -78,16 +65,6 @@ async fn generate(
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        tracing::error!("Model is overloaded");
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
 
     // Validate request
     let details = req.0.parameters.details;
@@ -185,12 +162,16 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
+    let batcher = Batcher::new(
+        client,
+        max_batch_size,
+        max_waiting_tokens,
+        max_concurrent_requests,
+    );
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
         batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
     };
 
     // Create router