// text-generation-inference — router/src/db.rs
//
// Request database for the router: buffers incoming generate requests and
// assembles them into batches for the inference server.
/// This code is massively inspired by Tokio mini-redis
use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
2022-10-18 13:19:03 +00:00
/// Database entry
2022-10-17 12:59:00 +00:00
#[derive(Debug)]
pub(crate) struct Entry {
2022-10-18 13:19:03 +00:00
/// Request
2022-10-17 12:59:00 +00:00
pub request: GenerateRequest,
2022-10-18 13:19:03 +00:00
/// Response sender to communicate between the Batcher and the batching_task
2022-10-17 12:59:00 +00:00
pub response_tx: Sender<Result<String, ClientError>>,
2022-10-18 13:19:03 +00:00
/// Number of tokens in the input
2022-10-17 12:59:00 +00:00
pub input_length: usize,
2022-10-18 13:19:03 +00:00
/// Instant when this entry was created
pub time: Instant,
2022-10-17 12:59:00 +00:00
}
/// Request Database
///
/// `Clone` is cheap: it only bumps the reference count on the shared state,
/// so handles can be passed freely between tasks.
#[derive(Debug, Clone)]
pub(crate) struct Db {
    pub shared: Arc<Shared>,
}
/// Shared state
#[derive(Debug)]
pub struct Shared {
    /// Inner mutable state, protected by a parking_lot mutex so it can be
    /// mutated through the shared `Arc<Shared>` handle
    state: Mutex<State>,
}
/// Database State
#[derive(Debug)]
struct State {
    /// Database entries organized in a BTreeMap to be able to iterate over them in order,
    /// keyed by the monotonically increasing id assigned at append time
    entries: BTreeMap<u64, Entry>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
    next_batch_start_id: u64,
}
2022-10-18 13:19:03 +00:00
impl State {
/// Get the next requests
fn next_requests(
&self,
max_size: usize,
min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Vec<Request>)> {
// Iterates for max_size over the BTreemap starting from next_batch_start_id
let mut requests = Vec::new();
let mut ids = Vec::new();
for (id, entry) in self
.entries
// Start from next_batch_start_id
.range(self.next_batch_start_id..)
// Take max_size
.take(max_size)
{
if let Some(min_waiting_time) = min_waiting_time {
// Only take entries that waited for at least min_waiting_time
if entry.time.elapsed() < min_waiting_time {
// Since entries are ordered, we already know that all following entries won't
// satisfy the condition
break;
}
}
requests.push(Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
});
ids.push(*id);
}
if requests.is_empty() {
None
} else {
Some((ids, requests))
}
}
}
2022-10-08 10:30:12 +00:00
impl Db {
pub(crate) fn new() -> Self {
2022-10-18 13:19:03 +00:00
// Shared state
2022-10-08 10:30:12 +00:00
let shared = Arc::new(Shared {
2022-10-18 13:19:03 +00:00
state: Mutex::new(State {
2022-10-08 10:30:12 +00:00
entries: BTreeMap::new(),
next_id: 0,
next_batch_id: 0,
next_batch_start_id: 0,
}),
});
Self { shared }
}
2022-10-18 13:19:03 +00:00
/// Append an entry to the database
2022-10-17 12:59:00 +00:00
pub(crate) fn append(&self, entry: Entry) {
2022-10-18 13:19:03 +00:00
// Acquire lock
let mut state = self.shared.state.lock();
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
// Insert entry
2022-10-08 10:30:12 +00:00
let id = state.next_id;
state.next_id += 1;
2022-10-17 12:59:00 +00:00
state.entries.insert(id, entry);
2022-10-08 10:30:12 +00:00
}
2022-10-18 13:19:03 +00:00
/// Remove an entry from the database if it exists
2022-10-17 12:59:00 +00:00
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
2022-10-18 13:19:03 +00:00
let mut state = self.shared.state.lock();
2022-10-08 10:30:12 +00:00
state.entries.remove(id)
}
2022-10-18 13:19:03 +00:00
// Get the next batch
pub(crate) fn next_batch(
&self,
min_size: Option<usize>,
max_size: usize,
min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Batch)> {
// Acquire lock
let mut state = self.shared.state.lock();
// Get requests from the database
if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
if let Some(min_size) = min_size {
// If min_size is set, only return a batch if there are enough requests
if requests.len() < min_size {
return None;
}
}
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
// Batch size
let size = requests.len();
2022-10-18 13:19:03 +00:00
// Longest input length for all requests in batch size
// Used for padding inside the inference server
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
2022-10-08 10:30:12 +00:00
let batch = Batch {
id: state.next_batch_id,
requests,
size: size as u32,
max_sequence_length,
2022-10-08 10:30:12 +00:00
};
2022-10-18 13:19:03 +00:00
// Update next_batch_start_id to the last id in the batch + 1
state.next_batch_start_id = ids.last().unwrap() + 1;
// Increment batch id
2022-10-08 10:30:12 +00:00
state.next_batch_id += 1;
2022-10-18 13:19:03 +00:00
return Some((ids, batch));
2022-10-08 10:30:12 +00:00
}
None
}
2022-10-18 13:19:03 +00:00
}
2022-10-08 10:30:12 +00:00
2022-10-18 13:19:03 +00:00
impl From<GenerateParameters> for LogitsWarperParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
2022-10-08 10:30:12 +00:00
}
}
}