Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
feat(router): Remove second lock from batcher hot path
This commit is contained in:
parent ce960be0a5
commit 67ee1907fc
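The change takes the per-request `Db` lock off the generation hot path: `Db::next_batch` now hands the queued `Entry` values to the batching task as a `HashMap<u64, Entry>`, and finished or failed requests are answered through the entries the task already owns instead of calling `db.remove(&id)` once per request. Below is a minimal sketch of the send side of that pattern, assuming the `tokio` crate; `Entry` and `send_error` here are simplified stand-ins for the router's types.

// Minimal sketch, assuming the `tokio` crate (e.g. with its "full" feature).
// `Entry` and `send_error` are simplified stand-ins for the router's types:
// once the batching task owns the entries, a failed batch is reported through
// each entry's oneshot sender without taking any Db lock.
use std::collections::HashMap;
use tokio::sync::oneshot;

struct Entry {
    // The real Entry also carries the request, timestamps, etc.
    response_tx: oneshot::Sender<Result<String, String>>,
}

/// Fail every in-flight request of a batch by draining the map the task owns.
fn send_error(error: String, entries: &mut HashMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // unwrap_or(()): a dropped receiver just means the client went away.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel();
    let mut entries = HashMap::from([(0u64, Entry { response_tx: tx })]);
    send_error("shard failed".to_string(), &mut entries);
    assert!(rx.await.unwrap().is_err());
}

Once a batch's entries have been moved out of the Db, completing them is pure channel traffic; only `next_batch` ever touches the shared mutex.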
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 /// Batching and inference logic
 use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
@@ -8,6 +9,7 @@ use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
+use tokio::sync::{Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -24,6 +26,8 @@ pub struct Batcher {
 struct Shared {
     /// Batching background Tokio task notifier
     batching_task: Notify,
+    /// Inference request limit
+    limit_concurrent_requests: Semaphore,
 }
 
 impl Batcher {
@@ -31,11 +35,13 @@ impl Batcher {
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
+        max_concurrent_requests: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
+            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
         });
 
         // Spawn batching background task that contains all the inference logic
@@ -56,6 +62,9 @@ impl Batcher {
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
+
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
@@ -104,8 +113,8 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
+            let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
             let mut waiting_tokens = 1;
 
             // We loop until we do not receive any cached batch from the inference server (== until
@@ -113,7 +122,6 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
-                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];
 
                 // If the current batch is too small, we try to add more requests to it
@@ -127,24 +135,24 @@ async fn batching_task(
                    };
 
                    // Try to get a new batch
-                    if let Some((new_request_ids, new_batch)) =
+                    if let Some((mut new_entries, new_batch)) =
                        db.next_batch(min_size, max_batch_size - batch_size as usize)
                    {
                        // Generate one token for this new batch to have the attention past in cache
                        let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                            wrap_future(client.generate(new_batch), &mut new_entries).await;
                        // Reset waiting counter
                        waiting_tokens = 1;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+                            entries.extend(new_entries);
                            batches.push(new_cached_batch);
                        }
                    }
                }
 
                cached_batch =
-                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                    wrap_future(client.generate_with_cache(batches), &mut entries).await;
                waiting_tokens += 1;
            }
        }
@@ -154,39 +162,36 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    request_ids: Vec<u64>,
-    db: &Db,
+    entries: &mut HashMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+            send_generated(generated_texts, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            send_error(err, request_ids, db);
+            send_error(err, entries);
             None
         }
     }
 }
 
-/// Send errors to the Batcher for all `request_ids`
-fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
-    request_ids.into_iter().for_each(|id| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+/// Send errors to the Batcher for all `entries`
+fn send_error(error: ClientError, entries: &mut HashMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut HashMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db
+        // We can `expect` here as the request id should always be in the entries
+        let entry = entries
             .remove(&output.request.unwrap().id)
-            .expect("ID not found in db. This is a bug.");
+            .expect("ID not found in entries. This is a bug.");
 
         let response = InferResponse {
             output_text: output.output_text,
@@ -221,18 +226,30 @@ pub(crate) struct InferResponse {
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
     GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded,
 }
 
+/// Convert semaphore error
+impl From<TryAcquireError> for InferError {
+    fn from(_: TryAcquireError) -> Self {
+        InferError::Overloaded
+    }
+}
+
 /// Convert to Axum supported format
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
-        match err {
-            InferError::GenerationError(_) => (
-                StatusCode::FAILED_DEPENDENCY,
+        let status_code = match err {
+            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
+            InferError::Overloaded => StatusCode::TOO_MANY_REQUESTS,
+        };
+
+        (
+            status_code,
             Json(ErrorResponse {
                 error: err.to_string(),
             }),
-            ),
-        }
+        )
     }
 }
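The batcher hunks above also move admission control into `Batcher` itself: the shared state gains a `Semaphore` sized by `max_concurrent_requests`, `infer` tries to take a permit before queuing work, and a failed `try_acquire` becomes `InferError::Overloaded`. A rough sketch of that guard, assuming the `tokio` and `thiserror` crates, with a simplified synchronous `infer`:

// Rough sketch of the admission guard, assuming the `tokio` and `thiserror`
// crates; `infer` is a simplified, synchronous stand-in for the router's
// async method.
use thiserror::Error;
use tokio::sync::{Semaphore, TryAcquireError};

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Model is overloaded")]
    Overloaded,
}

/// Convert semaphore error, as in the hunk above.
impl From<TryAcquireError> for InferError {
    fn from(_: TryAcquireError) -> Self {
        InferError::Overloaded
    }
}

pub struct Batcher {
    limit_concurrent_requests: Semaphore,
}

impl Batcher {
    pub fn new(max_concurrent_requests: usize) -> Self {
        Self {
            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
        }
    }

    pub fn infer(&self) -> Result<(), InferError> {
        // Holding the permit for the lifetime of the request caps concurrency;
        // try_acquire never waits, so excess requests fail fast with Overloaded.
        let _permit = self.limit_concurrent_requests.try_acquire()?;
        // ... validation, oneshot channel, hand-off to the batching task ...
        Ok(())
    }
}

fn main() {
    let batcher = Batcher::new(2);
    assert!(batcher.infer().is_ok());
}

Because the permit is tried rather than awaited, overload is reported immediately, and the `From<InferError>` conversion in the last hunk turns it into `429 Too Many Requests`.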
@@ -2,7 +2,7 @@ use crate::InferResponse
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use parking_lot::Mutex;
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -112,18 +112,12 @@ impl Db {
         state.entries.insert(id, entry);
     }
 
-    /// Remove an entry from the database if it exists
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.lock();
-        state.entries.remove(id)
-    }
-
     // Get the next batch
     pub(crate) fn next_batch(
         &self,
         min_size: Option<usize>,
         max_size: usize,
-    ) -> Option<(Vec<u64>, Batch)> {
+    ) -> Option<(HashMap<u64, Entry>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();
 
@@ -135,13 +129,19 @@ impl Db {
                    return None;
                }
            }
-            ids.iter().for_each(|id| {
-                // Set batch_time for each request
-                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
-            });
-
            // Batch size
            let size = requests.len();
+
+            let mut entries = HashMap::with_capacity(size);
+            ids.iter().for_each(|id| {
+                // Remove entry from db
+                let mut entry = state.entries.remove(id).unwrap();
+                // Set batch_time
+                entry.batch_time = Some(Instant::now());
+                // Insert in entries hashmap
+                entries.insert(*id, entry);
+            });
+
            let batch = Batch {
                id: state.next_batch_id,
                requests,
@@ -152,7 +152,7 @@ impl Db {
            // Increment batch id
            state.next_batch_id += 1;
 
-            return Some((ids, batch));
+            return Some((entries, batch));
        }
        None
    }
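The `Db` hunks above are where the second lock disappears: rather than returning request ids and being locked again later by `remove`, `next_batch` removes the entries itself while it already holds the state mutex and returns ownership to the caller. A minimal sketch of that single-lock drain; names are simplified, and `std::sync::Mutex` stands in for the `parking_lot::Mutex` the router actually uses:

// Minimal sketch of the single-lock drain. The router's Db wraps richer state
// in parking_lot::Mutex; std::sync::Mutex keeps this example dependency-free.
use std::collections::{BTreeMap, HashMap};
use std::sync::Mutex;

struct Entry {
    batch_time: Option<std::time::Instant>,
}

struct Db {
    entries: Mutex<BTreeMap<u64, Entry>>,
}

impl Db {
    /// Take up to `max_size` queued entries out of the shared state under a
    /// single lock and hand ownership to the caller (the batching task).
    fn next_batch(&self, max_size: usize) -> Option<HashMap<u64, Entry>> {
        let mut state = self.entries.lock().unwrap();
        let ids: Vec<u64> = state.keys().take(max_size).copied().collect();
        if ids.is_empty() {
            return None;
        }
        let mut batch_entries = HashMap::with_capacity(ids.len());
        for id in ids {
            // Remove the entry from the shared map and stamp its batch time;
            // nobody needs to come back for a second `remove` later.
            let mut entry = state.remove(&id).unwrap();
            entry.batch_time = Some(std::time::Instant::now());
            batch_entries.insert(id, entry);
        }
        Some(batch_entries)
    }
}

fn main() {
    let db = Db {
        entries: Mutex::new(BTreeMap::from([(0u64, Entry { batch_time: None })])),
    };
    let entries = db.next_batch(32).expect("one entry queued");
    assert!(entries[&0].batch_time.is_some());
}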
@@ -7,11 +7,9 @@ use axum::response::IntoResponse
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -20,7 +18,6 @@ use tracing::instrument
 struct ServerState {
     validation: Validation,
     batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Health check method
@@ -30,16 +27,6 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     // be a bit too slow for a health check.
     // What we should do instead if check if the gRPC channels are still healthy.
 
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
-
     // Send a small inference request
     state
         .batcher
@@ -78,16 +65,6 @@ async fn generate(
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
-        tracing::error!("Model is overloaded");
-        (
-            StatusCode::TOO_MANY_REQUESTS,
-            Json(ErrorResponse {
-                error: "Model is overloaded".to_string(),
-            }),
-        )
-    })?;
 
     // Validate request
     let details = req.0.parameters.details;
@@ -185,12 +162,16 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
+    let batcher = Batcher::new(
+        client,
+        max_batch_size,
+        max_waiting_tokens,
+        max_concurrent_requests,
+    );
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
         batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
     };
 
     // Create router
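On the server side, the hand-rolled semaphore and 429 body in `health` and `generate` become unnecessary: handlers can let the batcher's error type convert straight into an Axum response. A sketch of that handler shape, assuming the `axum` and `serde` crates, with a hypothetical synchronous `infer` standing in for `state.batcher.infer(...)`:

// Sketch only, assuming the `axum` and `serde` crates; `infer` is a
// hypothetical stand-in for the batcher call (the real one is async).
// With the From<InferError> conversion, `?` produces the 429 response that
// the handlers used to build by hand around their own semaphore.
use axum::http::StatusCode;
use axum::Json;
use serde::Serialize;

#[derive(Serialize)]
pub struct ErrorResponse {
    pub error: String,
}

pub enum InferError {
    Overloaded,
}

impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(_: InferError) -> Self {
        (
            StatusCode::TOO_MANY_REQUESTS,
            Json(ErrorResponse {
                error: "Model is overloaded".to_string(),
            }),
        )
    }
}

// Hypothetical stand-in for `state.batcher.infer(...)`.
fn infer() -> Result<String, InferError> {
    Err(InferError::Overloaded)
}

pub fn generate() -> Result<String, (StatusCode, Json<ErrorResponse>)> {
    // `?` applies the From conversion above: no explicit try_acquire and no
    // hand-built JSON body in the handler.
    let generated = infer()?;
    Ok(generated)
}

fn main() {
    assert!(generate().is_err());
}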