Mirror of https://github.com/huggingface/text-generation-inference.git
Proposal: Use bounded queue instead of database
Entries are moved into a hashmap owned by the batching loop at the time of request batching.
- Less code, less locking
- Fewer synchronization primitives: replaces the mutex, arc, notifier, and semaphore
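At a glance, the pattern looks roughly like the sketch below (my own illustration, not code from this diff: the `Job` type, the capacity of 2, and the echo handler are made-up stand-ins). A bounded tokio mpsc channel serves as both the request queue and the admission limit, and the consumer task owns a plain HashMap of in-flight work, so none of the removed primitives are needed:

use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};

// Illustrative request type; the real Entry carries the request, timings, etc.
struct Job {
    input: String,
    respond: oneshot::Sender<String>,
}

#[tokio::main]
async fn main() {
    // The channel capacity doubles as the "max concurrent requests" limit.
    let (tx, mut rx) = mpsc::channel::<Job>(2);

    // Consumer side: the loop owns its bookkeeping map outright, no locking.
    tokio::spawn(async move {
        let mut in_flight: HashMap<u64, Job> = HashMap::new();
        let mut next_id = 0u64;
        while let Some(job) = rx.recv().await {
            in_flight.insert(next_id, job);
            next_id += 1;
            // ... here the real loop would batch the entries and call the model ...
            for (_, job) in in_flight.drain() {
                let Job { input, respond } = job;
                let _ = respond.send(format!("echo: {input}"));
            }
        }
    });

    // Producer side: try_reserve() fails fast when the queue is full,
    // which is the role the Semaphore used to play.
    match tx.try_reserve() {
        Ok(permit) => {
            let (respond, response) = oneshot::channel();
            permit.send(Job { input: "hello".into(), respond });
            println!("{}", response.await.unwrap());
        }
        Err(_) => eprintln!("queue full, reject with 429"),
    }
}

In the actual change, `Batcher::reserve_slot()` wraps `Sender::try_reserve()`, the queue capacity is set to `max_concurrent_requests`, and the batching task keeps its in-flight entries in a `nohash-hasher` `IntMap` keyed by request id.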
This commit is contained in:
parent 1f570d181f, commit d0ccada7c0
Cargo.lock (generated): 7 changed lines
@@ -1087,6 +1087,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "nohash-hasher"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
+
 [[package]]
 name = "nom"
 version = "7.1.1"
@@ -1826,6 +1832,7 @@ dependencies = [
  "axum",
  "clap 4.0.22",
  "futures",
+ "nohash-hasher",
  "parking_lot",
  "serde",
  "serde_json",
router/Cargo.toml
@@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
+nohash-hasher = "0.2.0"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
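The one new dependency is `nohash-hasher`, which provides `IntMap`: a standard `HashMap` whose hasher passes the integer key through as-is. Request ids in the router are plain sequential `u64`s, so there is nothing useful to hash. A tiny illustration (the values are hypothetical):

use nohash_hasher::IntMap;

fn main() {
    // IntMap<u64, V> is a std HashMap with a hasher that uses the u64 key itself as the hash.
    let mut entries: IntMap<u64, &str> = IntMap::default();
    entries.insert(0, "first request");
    entries.insert(1, "second request");
    assert_eq!(entries.remove(&0), Some("first request"));
    assert_eq!(entries.len(), 1);
}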
router/src/batcher.rs
@@ -1,66 +1,70 @@
 /// Batching and inference logic
-use crate::{Db, Entry};
+use crate::Entry;
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
 use std::future::Future;
-use std::sync::Arc;
+use nohash_hasher::IntMap;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::oneshot;
+use tokio::sync::mpsc::{channel, Permit, Sender};
+use tokio::sync::mpsc::error::TrySendError;
 use tokio::time::Instant;
 use tracing::instrument;
+use crate::queue::Queue;
 
 /// Batcher
 #[derive(Clone)]
 pub struct Batcher {
-    /// Request database
-    db: Db,
-    /// Shared state
-    shared: Arc<Shared>,
+    /// Request queue
+    sender: Sender<Entry>,
 }
 
-/// Batcher shared state
-struct Shared {
-    /// Batching background Tokio task notifier
-    batching_task: Notify,
-}
-
 impl Batcher {
     pub(crate) fn new(
         client: ShardedClient,
         max_batch_size: usize,
         max_waiting_tokens: usize,
+        queue_size: usize,
     ) -> Self {
-        // Batcher shared state
-        let db = Db::new();
-        let shared = Arc::new(Shared {
-            batching_task: Notify::new(),
-        });
+        // Set up queue
+        let (sender, receiver) = channel(queue_size);
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
             max_batch_size,
             max_waiting_tokens,
-            db.clone(),
-            shared.clone(),
+            Queue::new(receiver),
         ));
 
-        Self { db, shared }
+        Self { sender }
     }
 
-    /// Add a new request to the database and return a future that will generate the text
+    /// Reserve a slot in the queue for sending a request
+    pub(crate) fn reserve_slot(&self) -> Result<RequestSender<'_>, TrySendError<()>> {
+        self.sender.try_reserve().map(|permit| RequestSender { permit })
+    }
+}
+
+pub(crate) struct RequestSender<'a> {
+    permit: Permit<'a, Entry>
+}
+
+impl <'a> RequestSender<'a> {
+    /// Add a new request to the queue and return a future that will generate the text
     pub(crate) async fn infer(
-        &self,
+        self,
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();
 
-        // Try to append the request to the database
-        self.db.append(Entry {
+        // Try to enqueue the request
+        self.permit.send(Entry {
             request,
             response_tx,
             input_length,
@@ -68,10 +72,6 @@ impl Batcher {
             batch_time: None,
         });
 
-        // Notify the background task that we have a new entry in the database that needs
-        // to be batched
-        self.shared.batching_task.notify_one();
-
         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
         response_rx
@@ -85,68 +85,69 @@ impl Batcher {
 /// Will be launched in a background Tokio task
 ///
 /// Batches requests and sends them to the inference server
-#[instrument(skip(client, db, shared))]
+#[instrument(skip(client, queue))]
 async fn batching_task(
     mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    db: Db,
-    shared: Arc<Shared>,
+    mut queue: Queue,
 ) {
     // Minimum batch size after which we try to add more requests
     let limit_min_batch_size = (max_batch_size / 2) as u32;
 
-    // Infinite loop
-    loop {
-        // Wait for a notification from the Batcher struct
-        shared.batching_task.notified().await;
-
-        // Get the next batch from the DB
-        // This batch might be smaller than the maximum batch size if there are not enough requests
-        // waiting in the DB
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            let mut waiting_tokens = 1;
+    // Entries corresponding to all of the in-progress requests
+    let mut entries = IntMap::default();
+
+    // Get the next batch from the queue
+    // This batch might be smaller than the maximum batch size if there are not enough requests
+    // waiting in the queue
+    while let Some(batch) = queue.next_batch(max_batch_size, &mut entries).await {
+        let mut cached_batch = wrap_future(
+            client.generate(batch), None, &mut entries
+        ).await;
+        let mut waiting_tokens = 1;
 
         // We loop until we do not receive any cached batch from the inference server (== until
         // all requests have met their stopping criteria)
         while let Some(batch) = cached_batch {
             // Get current batch info
             let batch_size = batch.size;
-            let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
             let mut batches = vec![batch];
 
             // If the current batch is too small, we try to add more requests to it
             if batch_size <= limit_min_batch_size {
                 let min_size = match waiting_tokens {
                     // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                     // to add a new batch even though its size might be small
-                    _ if waiting_tokens >= max_waiting_tokens => None,
+                    _ if waiting_tokens >= max_waiting_tokens => 1,
                     // Minimum size criteria
-                    _ => Some(limit_min_batch_size as usize),
+                    _ => limit_min_batch_size as usize,
                 };
 
                 // Try to get a new batch
-                if let Some((new_request_ids, new_batch)) =
-                    db.next_batch(min_size, max_batch_size - batch_size as usize)
-                {
+                if let Some(new_batch) = queue.try_next_batch(
+                    min_size, max_batch_size - batch_size as usize, &mut entries
+                ) {
+                    let first_new_id = new_batch.requests.first()
+                        .expect("batch can't be empty here").id;
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch =
-                        wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                    let new_cached_batch = wrap_future(
+                        client.generate(new_batch), Some(first_new_id), &mut entries
+                    ).await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
                     if let Some(new_cached_batch) = new_cached_batch {
-                        request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                         batches.push(new_cached_batch);
                     }
                 }
             }
 
-            cached_batch =
-                wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+            cached_batch = wrap_future(
+                client.generate_with_cache(batches), None, &mut entries
+            ).await;
             waiting_tokens += 1;
         }
     }
 }
@@ -154,39 +155,45 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    request_ids: Vec<u64>,
-    db: &Db,
+    // First request id in this batch if it doesn't comprise all current entries
+    start_id: Option<u64>,
+    entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+            send_generated(generated_texts, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            send_error(err, request_ids, db);
+            send_error(err, start_id, entries);
             None
         }
     }
 }
 
-/// Send errors to the Batcher for all `request_ids`
-fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
-    request_ids.into_iter().for_each(|id| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Err(error.clone())).unwrap_or(());
-    });
+/// Send errors to the Batcher for all failed entries
+fn send_error(error: ClientError, start_id: Option<u64>, entries: &mut IntMap<u64, Entry>) {
+    let to_keep = entries.drain().filter_map(|(id, entry)| match start_id {
+        // Keep entries that weren't in the failed request batch
+        Some(sid) if id < sid => Some((id, entry)),
+        _ => {
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            entry.response_tx.send(Err(error.clone())).unwrap_or(());
+            None
+        }
+    }).collect::<IntMap<u64, Entry>>();
+    // Workaround since drain_filter() is not yet stable. This will be empty when start_id == None.
+    entries.extend(to_keep);
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db
+        // We can `expect` here as the request id should always be in the map
+        let entry = entries
             .remove(&output.request.unwrap().id)
-            .expect("ID not found in db. This is a bug.");
+            .expect("ID not found. This is a bug.");
 
         let response = InferResponse {
             output_text: output.output_text,
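One detail worth calling out in the new `send_error`: `HashMap::retain` only hands out `&mut` references to the values, but failing an entry has to consume its oneshot sender, and `drain_filter` is not stable, hence the drain, collect, and re-insert dance the comment mentions. A standalone sketch of the same pattern (the ids and values are hypothetical):

use std::collections::HashMap;

fn main() {
    // Hypothetical stand-in for the entries map: ids below the cutoff are kept,
    // everything else is treated as failed and dropped after a side effect.
    let mut entries: HashMap<u64, &str> = HashMap::from([(0, "a"), (1, "b"), (2, "c")]);
    let cutoff: Option<u64> = Some(1);

    // drain() empties the map; filter_map decides per item whether to keep it,
    // taking ownership of the value (which retain() would not allow).
    let to_keep: HashMap<u64, &str> = entries
        .drain()
        .filter_map(|(id, value)| match cutoff {
            Some(c) if id < c => Some((id, value)),
            _ => {
                println!("dropping failed entry {id}: {value}");
                None
            }
        })
        .collect();

    // Put the survivors back, since drain() removed everything.
    entries.extend(to_keep);
    assert_eq!(entries.len(), 1);
}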
router/src/db.rs (deleted): 179 lines
@@ -1,179 +0,0 @@
-use crate::InferResponse;
-/// This code is massively inspired by Tokio mini-redis
-use crate::{GenerateParameters, GenerateRequest};
-use parking_lot::Mutex;
-use std::collections::BTreeMap;
-use std::sync::Arc;
-use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
-};
-use tokio::sync::oneshot::Sender;
-use tokio::time::Instant;
-
-/// Database entry
-#[derive(Debug)]
-pub(crate) struct Entry {
-    /// Request
-    pub request: GenerateRequest,
-    /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<InferResponse, ClientError>>,
-    /// Number of tokens in the input
-    pub input_length: usize,
-    /// Instant when this entry was created
-    pub time: Instant,
-    /// Instant when this entry was added to a batch
-    pub batch_time: Option<Instant>,
-}
-
-/// Request Database
-#[derive(Debug, Clone)]
-pub(crate) struct Db {
-    pub shared: Arc<Shared>,
-}
-
-/// Shared state
-#[derive(Debug)]
-pub struct Shared {
-    state: Mutex<State>,
-}
-
-/// Database State
-#[derive(Debug)]
-struct State {
-    /// Database entries organized in a BTreeMap to be able to iterate over them in order
-    entries: BTreeMap<u64, Entry>,
-
-    /// Id of the next entry
-    next_id: u64,
-
-    /// Id of the next batch
-    next_batch_id: u64,
-
-    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
-    next_batch_start_id: u64,
-}
-
-impl State {
-    /// Get the next requests
-    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
-        // Iterates for max_size over the BTreemap starting from next_batch_start_id
-        let mut requests = Vec::new();
-        let mut ids = Vec::new();
-
-        for (id, entry) in self
-            .entries
-            // Start from next_batch_start_id
-            .range(self.next_batch_start_id..)
-            // Take max_size
-            .take(max_size)
-        {
-            requests.push(Request {
-                id: *id,
-                inputs: entry.request.inputs.clone(),
-                input_length: entry.input_length as u32,
-                parameters: Some((&entry.request.parameters).into()),
-                stopping_parameters: Some(entry.request.parameters.clone().into()),
-            });
-
-            ids.push(*id);
-        }
-
-        if requests.is_empty() {
-            None
-        } else {
-            Some((ids, requests))
-        }
-    }
-}
-
-impl Db {
-    pub(crate) fn new() -> Self {
-        // Shared state
-        let shared = Arc::new(Shared {
-            state: Mutex::new(State {
-                entries: BTreeMap::new(),
-                next_id: 0,
-                next_batch_id: 0,
-                next_batch_start_id: 0,
-            }),
-        });
-
-        Self { shared }
-    }
-
-    /// Append an entry to the database
-    pub(crate) fn append(&self, entry: Entry) {
-        // Acquire lock
-        let mut state = self.shared.state.lock();
-
-        // Insert entry
-        let id = state.next_id;
-        state.next_id += 1;
-        state.entries.insert(id, entry);
-    }
-
-    /// Remove an entry from the database if it exists
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.lock();
-        state.entries.remove(id)
-    }
-
-    // Get the next batch
-    pub(crate) fn next_batch(
-        &self,
-        min_size: Option<usize>,
-        max_size: usize,
-    ) -> Option<(Vec<u64>, Batch)> {
-        // Acquire lock
-        let mut state = self.shared.state.lock();
-
-        // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size) {
-            if let Some(min_size) = min_size {
-                // If min_size is set, only return a batch if there are enough requests
-                if requests.len() < min_size {
-                    return None;
-                }
-            }
-            ids.iter().for_each(|id| {
-                // Set batch_time for each request
-                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
-            });
-
-            // Batch size
-            let size = requests.len();
-            let batch = Batch {
-                id: state.next_batch_id,
-                requests,
-                size: size as u32,
-            };
-            // Update next_batch_start_id to the last id in the batch + 1
-            state.next_batch_start_id = ids.last().unwrap() + 1;
-            // Increment batch id
-            state.next_batch_id += 1;
-
-            return Some((ids, batch));
-        }
-        None
-    }
-}
-
-impl From<&GenerateParameters> for NextTokenChooserParameters {
-    fn from(parameters: &GenerateParameters) -> Self {
-        Self {
-            temperature: parameters.temperature,
-            top_k: parameters.top_k as u32,
-            top_p: parameters.top_p,
-            do_sample: parameters.do_sample,
-        }
-    }
-}
-
-impl From<GenerateParameters> for StoppingCriteriaParameters {
-    fn from(parameters: GenerateParameters) -> Self {
-        Self {
-            stop_sequences: parameters.stop,
-            max_new_tokens: parameters.max_new_tokens,
-        }
-    }
-}
router/src/lib.rs
@@ -1,11 +1,11 @@
 /// Text Generation Inference Webserver
 mod batcher;
-mod db;
+mod queue;
 pub mod server;
 mod validation;
 
 use batcher::{Batcher, InferResponse};
-use db::{Db, Entry};
+use queue::Entry;
 use serde::{Deserialize, Serialize};
 use validation::Validation;
 
router/src/queue.rs (new file): 137 lines
@@ -0,0 +1,137 @@
+use std::cmp::min;
+use crate::InferResponse;
+use crate::{GenerateParameters, GenerateRequest};
+use std::collections::VecDeque;
+use nohash_hasher::IntMap;
+use tokio::sync::mpsc::Receiver;
+use text_generation_client::{
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
+use tokio::sync::oneshot::Sender;
+use tokio::time::Instant;
+
+/// In-flight request record
+#[derive(Debug)]
+pub(crate) struct Entry {
+    /// Request
+    pub request: GenerateRequest,
+    /// Response sender to communicate between the Batcher and the batching_task
+    pub response_tx: Sender<Result<InferResponse, ClientError>>,
+    /// Number of tokens in the input
+    pub input_length: usize,
+    /// Instant when this entry was created
+    pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
+}
+
+/// Request Queue
+#[derive(Debug)]
+pub(crate) struct Queue {
+    receiver: Receiver<Entry>,
+    buffer: VecDeque<Entry>,
+    /// Id of the next entry
+    next_id: u64,
+    /// Id of the next batch
+    next_batch_id: u64,
+}
+
+impl Queue {
+    pub(crate) fn new(receiver: Receiver<Entry>) -> Self {
+        Self { receiver, buffer: VecDeque::new(), next_id: 0, next_batch_id: 0 }
+    }
+
+    /// Get the next batch, blocking until available
+    /// Corresponding entries are added to the entries map
+    pub(crate) async fn next_batch(
+        &mut self,
+        max_size: usize,
+        entries: &mut IntMap<u64, Entry>,
+    ) -> Option<Batch> {
+        loop {
+            if self.buffer.is_empty() {
+                match self.receiver.recv().await {
+                    Some(ent) => self.buffer.push_back(ent),
+                    None => return None,
+                }
+            }
+            if let Some(batch) = self.try_next_batch(1, max_size, entries) {
+                return Some(batch)
+            }
+        }
+    }
+
+    /// Get the next batch without blocking
+    /// Corresponding entries are added to the entries map
+    pub(crate) fn try_next_batch(
+        &mut self,
+        min_size: usize,
+        max_size: usize,
+        entries: &mut IntMap<u64, Entry>,
+    ) -> Option<Batch> {
+        while self.buffer.len() < max_size {
+            match self.receiver.try_recv() {
+                Ok(ent) => self.buffer.push_back(ent),
+                _ => break,
+            }
+        }
+
+        let len = self.buffer.len();
+        if len < min_size || len == 0 {
+            // Can't get minimum
+            return None;
+        }
+
+        let now = Some(Instant::now());
+        let requests = self.buffer.drain(..min(len, max_size))
+            .map(|mut entry| {
+                let id = self.next_id;
+                self.next_id += 1;
+                let request = Request {
+                    id,
+                    inputs: entry.request.inputs.clone(),
+                    input_length: entry.input_length as u32,
+                    parameters: Some((&entry.request.parameters).into()),
+                    stopping_parameters: Some(entry.request.parameters.clone().into()),
+                };
+                entry.batch_time = now;
+                entries.insert(id, entry);
+                request
+            })
+            .collect::<Vec<Request>>();
+
+        // Batch size
+        let size = requests.len();
+        let batch = Batch {
+            id: self.next_batch_id,
+            requests,
+            size: size as u32,
+        };
+        // Increment batch id
+        self.next_batch_id += 1;
+
+        Some(batch)
+    }
+}
+
+impl From<&GenerateParameters> for NextTokenChooserParameters {
+    fn from(parameters: &GenerateParameters) -> Self {
+        Self {
+            temperature: parameters.temperature,
+            top_k: parameters.top_k as u32,
+            top_p: parameters.top_p,
+            do_sample: parameters.do_sample,
+        }
+    }
+}
+
+impl From<GenerateParameters> for StoppingCriteriaParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            stop_sequences: parameters.stop,
+            max_new_tokens: parameters.max_new_tokens,
+        }
+    }
+}
router/src/server.rs
@@ -7,11 +7,9 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
-use std::sync::Arc;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
-use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
 
@@ -20,7 +18,6 @@ use tracing::instrument;
 struct ServerState {
     validation: Validation,
     batcher: Batcher,
-    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Health check method
@@ -30,8 +27,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     // be a bit too slow for a health check.
     // What we should do instead if check if the gRPC channels are still healthy.
 
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
+    // Limit concurrent requests by reserving a slot in the queue
+    let sender = state.batcher.reserve_slot().map_err(|_| {
         (
             StatusCode::TOO_MANY_REQUESTS,
             Json(ErrorResponse {
@@ -41,24 +38,22 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     })?;
 
     // Send a small inference request
-    state
-        .batcher
-        .infer(
-            1,
-            GenerateRequest {
-                inputs: "liveness".to_string(),
-                parameters: GenerateParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    do_sample: false,
-                    max_new_tokens: 1,
-                    stop: vec![],
-                    details: false,
-                },
-            },
-        )
-        .await?;
+    sender.infer(
+        1,
+        GenerateRequest {
+            inputs: "liveness".to_string(),
+            parameters: GenerateParameters {
+                temperature: 1.0,
+                top_k: 0,
+                top_p: 1.0,
+                do_sample: false,
+                max_new_tokens: 1,
+                stop: vec![],
+                details: false,
+            },
+        },
+    )
+    .await?;
     Ok(())
 }
 
@@ -78,8 +73,8 @@ async fn generate(
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
-    // Limit concurrent requests by acquiring a permit from the semaphore
-    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
+    // Limit concurrent requests by reserving a slot in the queue
+    let sender = state.batcher.reserve_slot().map_err(|_| {
         tracing::error!("Model is overloaded");
         (
             StatusCode::TOO_MANY_REQUESTS,
@@ -98,8 +93,7 @@ async fn generate(
     })?;
 
     // Inference
-    let response = state
-        .batcher
+    let response = sender
         .infer(input_length, validated_request)
         .await
         .map_err(|err| {
@@ -185,12 +179,13 @@ pub async fn run(
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
+    let batcher = Batcher::new(
+        client, max_batch_size, max_waiting_tokens, max_concurrent_requests
+    );
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
         batcher,
-        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
     };
 
     // Create router