Port over block allocator interface (with token ids)

This commit is contained in:
Daniël de Kok 2024-08-01 13:41:07 +00:00
parent 4562c16048
commit 5d482d4da2
3 changed files with 123 additions and 51 deletions

View File

@ -1,16 +1,18 @@
use std::cmp::min;
use std::{cmp::min, sync::Arc};
use tokio::sync::{mpsc, oneshot};
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
pub allocation_id: u64,
block_allocator: BlockAllocator,
}
impl Drop for BlockAllocation {
fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone())
self.block_allocator
.free(self.blocks.clone(), self.allocation_id)
}
}
@ -42,11 +44,16 @@ impl BlockAllocator {
}
}
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
pub(crate) async fn allocate(
&self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
prefill_tokens,
response_sender,
})
.unwrap();
@ -54,16 +61,20 @@ impl BlockAllocator {
response_receiver
.await
.unwrap()
.map(|(blocks, slots)| BlockAllocation {
.map(|(blocks, slots, allocation_id)| BlockAllocation {
blocks,
slots,
allocation_id,
block_allocator: self.clone(),
})
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
.send(BlockAllocatorCommand::Free {
allocation_id,
blocks,
})
.unwrap();
}
}
@ -74,63 +85,111 @@ async fn block_allocator_task(
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
// Block 0 is reserved for health checks
let mut free_blocks: Vec<u32> = (1..blocks).collect();
let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
while let Some(cmd) = receiver.recv().await {
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Free {
blocks,
allocation_id,
} => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate {
tokens,
prefill_tokens,
response_sender,
} => {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
response_sender
.send(allocator.allocate(tokens, prefill_tokens_slice))
.unwrap();
}
}
}
}
pub trait Allocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<&[u32]>,
) -> Option<(Vec<u32>, Vec<u32>, u64)>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<&[u32]>,
) -> Option<(Vec<u32>, Vec<u32>, u64)> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots, 0))
}
}
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}
#[derive(Debug)]
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
allocation_id: u64,
},
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
prefill_tokens: Option<Arc<Vec<u32>>>,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>, u64)>>,
},
}

View File

@ -298,7 +298,10 @@ impl State {
+ self.speculate
- 1;
match block_allocator.allocate(tokens).await {
match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
None => {
// Entry is over budget
// Add it back to the front

View File

@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
@ -121,7 +122,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
@ -156,8 +157,10 @@ impl Validation {
));
}
let input_ids = encoding.get_ids()[..input_length].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens))
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
}
// Return inputs without validation
else {
@ -180,7 +183,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
}
}
@ -314,7 +322,7 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
let (inputs, input_length, max_new_tokens) = self
let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
@ -391,6 +399,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
input_ids: input_ids.map(Arc::new),
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -707,6 +716,7 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)]
pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,