/// This code is massively inspired by Tokio mini-redis
use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;

/// Database entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: GenerateRequest,
    /// Response sender to communicate between the Batcher and the batching_task
    pub response_tx: Sender<Result<String, ClientError>>,
    /// Number of tokens in the input
    pub input_length: usize,
    /// Instant when this entry was created
    pub time: Instant,
}

/// Request Database
#[derive(Debug, Clone)]
pub(crate) struct Db {
    pub shared: Arc<Shared>,
}

/// Shared state
#[derive(Debug)]
pub struct Shared {
    state: Mutex<State>,
}
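
// Note: a synchronous parking_lot::Mutex (rather than an async lock) seems
// sufficient here; the lock is never held across an .await point and every
// critical section below is short.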

/// Database State
#[derive(Debug)]
struct State {
    /// Database entries organized in a BTreeMap to be able to iterate over them in order
    entries: BTreeMap<u64, Entry>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
    next_batch_start_id: u64,
}
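
// Invariant (inferred from the methods below): `append` hands out strictly
// increasing ids, so `next_batch_start_id` acts as a cursor into `entries`:
// everything below it has already been put into a batch and only stays in the
// map until `remove` is called for it.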

impl State {
    /// Get the next requests
    fn next_requests(
        &self,
        max_size: usize,
        min_waiting_time: Option<Duration>,
    ) -> Option<(Vec<u64>, Vec<Request>)> {
        // Iterate over at most max_size entries of the BTreeMap, starting from next_batch_start_id
        let mut requests = Vec::new();
        let mut ids = Vec::new();

        for (id, entry) in self
            .entries
            // Start from next_batch_start_id
            .range(self.next_batch_start_id..)
            // Take max_size
            .take(max_size)
        {
            if let Some(min_waiting_time) = min_waiting_time {
                // Only take entries that waited for at least min_waiting_time
                if entry.time.elapsed() < min_waiting_time {
                    // Since entries are ordered by insertion time, all following entries won't
                    // satisfy the condition either
                    break;
                }
            }

            requests.push(Request {
                id: *id,
                inputs: entry.request.inputs.clone(),
                input_length: entry.input_length as u32,
                parameters: Some(LogitsWarperParameters::from(
                    entry.request.parameters.clone(),
                )),
                max_new_tokens: entry.request.parameters.max_new_tokens,
            });

            ids.push(*id);
        }

        if requests.is_empty() {
            None
        } else {
            Some((ids, requests))
        }
    }
}

impl Db {
    pub(crate) fn new() -> Self {
        // Shared state
        let shared = Arc::new(Shared {
            state: Mutex::new(State {
                entries: BTreeMap::new(),
                next_id: 0,
                next_batch_id: 0,
                next_batch_start_id: 0,
            }),
        });

        Self { shared }
    }

    /// Append an entry to the database
    pub(crate) fn append(&self, entry: Entry) {
        // Acquire lock
        let mut state = self.shared.state.lock();

        // Insert entry
        let id = state.next_id;
        state.next_id += 1;
        state.entries.insert(id, entry);
    }

    /// Remove an entry from the database if it exists
    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
        let mut state = self.shared.state.lock();
        state.entries.remove(id)
    }

    /// Get the next batch
    pub(crate) fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: usize,
        min_waiting_time: Option<Duration>,
    ) -> Option<(Vec<u64>, Batch)> {
        // Acquire lock
        let mut state = self.shared.state.lock();

        // Get requests from the database
        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
            if let Some(min_size) = min_size {
                // If min_size is set, only return a batch if there are enough requests
                if requests.len() < min_size {
                    return None;
                }
            }

            // Batch size
            let size = requests.len();
            // Longest input length of all requests in the batch
            // Used for padding inside the inference server
            let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
            let batch = Batch {
                id: state.next_batch_id,
                requests,
                size: size as u32,
                max_sequence_length,
            };
            // Update next_batch_start_id to the last id in the batch + 1
            state.next_batch_start_id = ids.last().unwrap() + 1;
            // Increment batch id
            state.next_batch_id += 1;

            return Some((ids, batch));
        }
        None
    }
}
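
// Typical flow (a sketch, not part of the original file): a request handler
// appends an `Entry` holding the request and a oneshot sender, while the
// batching task periodically asks for work. Building a real `Entry` requires
// the crate's `GenerateRequest`, so its construction is elided here.
//
//     let db = Db::new();
//     db.append(entry); // entry: Entry built from an incoming GenerateRequest
//     if let Some((ids, batch)) = db.next_batch(None, 32, None) {
//         // forward `batch` to the inference server, then drop finished entries
//         for id in &ids {
//             db.remove(id);
//         }
//     }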

impl From<GenerateParameters> for LogitsWarperParameters {
    fn from(parameters: GenerateParameters) -> Self {
        Self {
            temperature: parameters.temperature,
            top_k: parameters.top_k as u32,
            top_p: parameters.top_p,
            do_sample: parameters.do_sample,
        }
    }
}
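
// A minimal test sketch (not part of the original file): with no entries
// appended, `next_batch` should never produce a batch, whatever limits are
// passed.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_db_returns_no_batch() {
        let db = Db::new();
        assert!(db.next_batch(None, 32, None).is_none());
        assert!(db.next_batch(Some(1), 32, None).is_none());
    }
}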