2022-10-08 10:30:12 +00:00
|
|
|
//! This code is heavily inspired by Tokio's mini-redis (https://github.com/tokio-rs/mini-redis).
|
2023-01-31 10:49:43 +00:00
|
|
|
use crate::infer::InferError;
|
|
|
|
use crate::infer::InferStreamResponse;
|
2022-10-18 13:19:03 +00:00
|
|
|
use crate::{GenerateParameters, GenerateRequest};
|
2023-01-26 15:29:13 +00:00
|
|
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
2022-10-18 13:19:03 +00:00
|
|
|
use parking_lot::Mutex;
|
2022-10-08 10:30:12 +00:00
|
|
|
use std::collections::BTreeMap;
|
|
|
|
use std::sync::Arc;
|
2022-12-12 17:25:22 +00:00
|
|
|
use text_generation_client::{
|
2023-01-31 10:49:43 +00:00
|
|
|
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
2022-12-12 17:25:22 +00:00
|
|
|
};
|
2023-01-31 10:49:43 +00:00
|
|
|
use tokio::sync::mpsc::UnboundedSender;
|
|
|
|
use tokio::sync::OwnedSemaphorePermit;
|
2022-10-18 13:19:03 +00:00
|
|
|
use tokio::time::Instant;
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Database entry
|
2022-10-17 12:59:00 +00:00
|
|
|
#[derive(Debug)]
|
|
|
|
pub(crate) struct Entry {
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Request
|
2022-10-17 12:59:00 +00:00
|
|
|
pub request: GenerateRequest,
|
2023-01-31 10:49:43 +00:00
|
|
|
/// Response sender to communicate between the Infer struct and the batching_task
|
|
|
|
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Number of tokens in the input
|
2022-10-17 12:59:00 +00:00
|
|
|
pub input_length: usize,
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Instant when this entry was created
|
|
|
|
pub time: Instant,
|
2022-10-21 14:40:05 +00:00
|
|
|
/// Instant when this entry was added to a batch
|
|
|
|
pub batch_time: Option<Instant>,
|
2023-01-31 10:49:43 +00:00
|
|
|
/// Permit
|
|
|
|
pub _permit: OwnedSemaphorePermit,
|
2022-10-17 12:59:00 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Request Database
|
2022-10-08 10:30:12 +00:00
|
|
|
#[derive(Debug, Clone)]
|
|
|
|
pub(crate) struct Db {
|
|
|
|
pub shared: Arc<Shared>,
|
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Shared state
|
2022-10-08 10:30:12 +00:00
|
|
|
#[derive(Debug)]
|
|
|
|
pub struct Shared {
|
2022-10-18 13:19:03 +00:00
|
|
|
state: Mutex<State>,
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Database State
|
2022-10-08 10:30:12 +00:00
|
|
|
#[derive(Debug)]
|
|
|
|
struct State {
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Database entries organized in a BTreeMap to be able to iterate over them in order
|
2022-10-17 12:59:00 +00:00
|
|
|
entries: BTreeMap<u64, Entry>,
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Id of the next entry
|
2022-10-08 10:30:12 +00:00
|
|
|
next_id: u64,
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Id of the next batch
|
2022-10-08 10:30:12 +00:00
|
|
|
next_batch_id: u64,
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Start ID of the next batch. Used to iterate inside the entries BTreeMap
|
2022-10-08 10:30:12 +00:00
|
|
|
next_batch_start_id: u64,
|
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
impl State {
|
|
|
|
/// Get the next requests
|
2022-10-21 14:40:05 +00:00
|
|
|
fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
|
2022-10-18 13:19:03 +00:00
|
|
|
// Iterates for max_size over the BTreemap starting from next_batch_start_id
|
|
|
|
let mut requests = Vec::new();
|
|
|
|
let mut ids = Vec::new();
|
|
|
|
|
|
|
|
for (id, entry) in self
|
|
|
|
.entries
|
|
|
|
// Start from next_batch_start_id
|
|
|
|
.range(self.next_batch_start_id..)
|
|
|
|
// Take max_size
|
|
|
|
.take(max_size)
|
|
|
|
{
|
|
|
|
requests.push(Request {
|
|
|
|
id: *id,
|
|
|
|
inputs: entry.request.inputs.clone(),
|
|
|
|
input_length: entry.input_length as u32,
|
2023-01-03 09:41:22 +00:00
|
|
|
parameters: Some((&entry.request.parameters).into()),
|
|
|
|
stopping_parameters: Some(entry.request.parameters.clone().into()),
|
2022-10-18 13:19:03 +00:00
|
|
|
});
|
|
|
|
|
|
|
|
ids.push(*id);
|
|
|
|
}
|
|
|
|
|
|
|
|
if requests.is_empty() {
|
|
|
|
None
|
|
|
|
} else {
|
|
|
|
Some((ids, requests))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-10-08 10:30:12 +00:00
|
|
|
impl Db {
|
|
|
|
pub(crate) fn new() -> Self {
|
2022-10-18 13:19:03 +00:00
|
|
|
// Shared state
|
2022-10-08 10:30:12 +00:00
|
|
|
let shared = Arc::new(Shared {
|
2022-10-18 13:19:03 +00:00
|
|
|
state: Mutex::new(State {
|
2022-10-08 10:30:12 +00:00
|
|
|
entries: BTreeMap::new(),
|
|
|
|
next_id: 0,
|
|
|
|
next_batch_id: 0,
|
|
|
|
next_batch_start_id: 0,
|
|
|
|
}),
|
|
|
|
});
|
|
|
|
|
|
|
|
Self { shared }
|
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
/// Append an entry to the database
|
2022-10-17 12:59:00 +00:00
|
|
|
pub(crate) fn append(&self, entry: Entry) {
|
2022-10-18 13:19:03 +00:00
|
|
|
// Acquire lock
|
|
|
|
let mut state = self.shared.state.lock();
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Insert entry
|
2022-10-08 10:30:12 +00:00
|
|
|
let id = state.next_id;
|
|
|
|
state.next_id += 1;
|
2022-10-17 12:59:00 +00:00
|
|
|
state.entries.insert(id, entry);
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
|
2022-10-18 13:19:03 +00:00
|
|
|
// Get the next batch
|
|
|
|
pub(crate) fn next_batch(
|
|
|
|
&self,
|
|
|
|
min_size: Option<usize>,
|
|
|
|
max_size: usize,
|
2023-01-26 15:29:13 +00:00
|
|
|
) -> Option<(IntMap<u64, Entry>, Batch)> {
|
2022-10-18 13:19:03 +00:00
|
|
|
// Acquire lock
|
|
|
|
let mut state = self.shared.state.lock();
|
|
|
|
|
|
|
|
// Get requests from the database
|
2022-10-21 14:40:05 +00:00
|
|
|
if let Some((ids, requests)) = state.next_requests(max_size) {
|
2022-10-18 13:19:03 +00:00
|
|
|
if let Some(min_size) = min_size {
|
|
|
|
// If min_size is set, only return a batch if there are enough requests
|
|
|
|
if requests.len() < min_size {
|
|
|
|
return None;
|
|
|
|
}
|
|
|
|
}
|
2023-01-26 15:29:13 +00:00
|
|
|
// Batch size
|
|
|
|
let size = requests.len();
|
|
|
|
|
|
|
|
let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
|
2022-10-21 14:40:05 +00:00
|
|
|
ids.iter().for_each(|id| {
|
2023-01-26 15:29:13 +00:00
|
|
|
// Remove entry from db
|
|
|
|
let mut entry = state.entries.remove(id).unwrap();
|
|
|
|
// Set batch_time
|
|
|
|
entry.batch_time = Some(Instant::now());
|
|
|
|
// Insert in entries IntMap
|
|
|
|
entries.insert(*id, entry);
|
2022-10-21 14:40:05 +00:00
|
|
|
});
|
2022-10-08 10:30:12 +00:00
|
|
|
|
|
|
|
let batch = Batch {
|
|
|
|
id: state.next_batch_id,
|
|
|
|
requests,
|
2022-10-11 14:50:54 +00:00
|
|
|
size: size as u32,
|
2022-10-08 10:30:12 +00:00
|
|
|
};
|
2022-10-18 13:19:03 +00:00
|
|
|
// Update next_batch_start_id to the last id in the batch + 1
|
|
|
|
state.next_batch_start_id = ids.last().unwrap() + 1;
|
|
|
|
// Increment batch id
|
2022-10-08 10:30:12 +00:00
|
|
|
state.next_batch_id += 1;
|
2022-10-18 13:19:03 +00:00
|
|
|
|
2023-01-26 15:29:13 +00:00
|
|
|
return Some((entries, batch));
|
2022-10-08 10:30:12 +00:00
|
|
|
}
|
|
|
|
None
|
|
|
|
}
|
2022-10-18 13:19:03 +00:00
|
|
|
}
|
2022-10-08 10:30:12 +00:00
|
|
|
|
2023-01-03 09:41:22 +00:00
|
|
|
impl From<&GenerateParameters> for NextTokenChooserParameters {
    /// Copy the sampling-related fields of a request's parameters into the
    /// client-side token chooser parameters. `top_k` is converted with an `as u32` cast.
    fn from(params: &GenerateParameters) -> Self {
        let top_k = params.top_k as u32;
        Self {
            temperature: params.temperature,
            top_k,
            top_p: params.top_p,
            do_sample: params.do_sample,
            seed: params.seed,
        }
    }
}
|
2022-12-12 17:25:22 +00:00
|
|
|
|
|
|
|
impl From<GenerateParameters> for StoppingCriteriaParameters {
|
|
|
|
fn from(parameters: GenerateParameters) -> Self {
|
|
|
|
Self {
|
|
|
|
stop_sequences: parameters.stop,
|
|
|
|
max_new_tokens: parameters.max_new_tokens,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|