Port over block allocator interface (with token ids)

This commit is contained in:
Daniël de Kok 2024-08-01 13:41:07 +00:00
parent 4562c16048
commit 5d482d4da2
3 changed files with 123 additions and 51 deletions

View File

@ -1,16 +1,18 @@
use std::cmp::min; use std::{cmp::min, sync::Arc};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocation { pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
pub allocation_id: u64,
block_allocator: BlockAllocator, block_allocator: BlockAllocator,
} }
impl Drop for BlockAllocation { impl Drop for BlockAllocation {
fn drop(&mut self) { fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone()) self.block_allocator
.free(self.blocks.clone(), self.allocation_id)
} }
} }
@ -42,11 +44,16 @@ impl BlockAllocator {
} }
} }
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> { pub(crate) async fn allocate(
&self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel(); let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Allocate { .send(BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
}) })
.unwrap(); .unwrap();
@ -54,16 +61,20 @@ impl BlockAllocator {
response_receiver response_receiver
.await .await
.unwrap() .unwrap()
.map(|(blocks, slots)| BlockAllocation { .map(|(blocks, slots, allocation_id)| BlockAllocation {
blocks, blocks,
slots, slots,
allocation_id,
block_allocator: self.clone(), block_allocator: self.clone(),
}) })
} }
pub(crate) fn free(&self, blocks: Vec<u32>) { pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Free { blocks }) .send(BlockAllocatorCommand::Free {
allocation_id,
blocks,
})
.unwrap(); .unwrap();
} }
} }
@ -74,18 +85,63 @@ async fn block_allocator_task(
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>, mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) { ) {
// Block 0 is reserved for health checks let mut allocator = SimpleAllocator::new(blocks, block_size, window_size);
let mut free_blocks: Vec<u32> = (1..blocks).collect();
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), BlockAllocatorCommand::Free {
blocks,
allocation_id,
} => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate { BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
} => { } => {
let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
response_sender
.send(allocator.allocate(tokens, prefill_tokens_slice))
.unwrap();
}
}
}
}
/// Block allocation strategy used by the allocator task to serve
/// allocate and free commands.
pub trait Allocator {
/// Requests blocks for `tokens` tokens.
///
/// `prefill_tokens` optionally carries the token ids of the request's
/// input; an implementation is free to ignore it (`SimpleAllocator` does).
///
/// Returns a tuple that callers destructure as
/// `(blocks, slots, allocation_id)`, or `None` when the request cannot be
/// served (`SimpleAllocator` returns `None` when it has fewer free blocks
/// than the request needs).
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<&[u32]>,
) -> Option<(Vec<u32>, Vec<u32>, u64)>;
/// Hands `blocks` back to the allocator, together with the
/// `allocation_id` that `allocate` returned for them.
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
/// Allocator backed by a plain list of free block ids.
pub struct SimpleAllocator {
    /// Ids of the blocks that are currently available.
    free_blocks: Vec<u32>,
    /// Number of slots per block.
    block_size: u32,
    /// Optional window size applied to a request before blocks are counted.
    window_size: Option<u32>,
}

impl SimpleAllocator {
    /// Builds an allocator whose free list holds the block ids `1..blocks`.
    ///
    /// `block_size` and `window_size` are stored unchanged for later use
    /// when serving allocations.
    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
        // Block 0 is reserved for health checks
        let free_blocks: Vec<u32> = (1..blocks).collect();

        Self {
            free_blocks,
            block_size,
            window_size,
        }
    }
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<&[u32]>,
) -> Option<(Vec<u32>, Vec<u32>, u64)> {
// Apply window size // Apply window size
let (required_blocks, repeats) = { let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size { let (tokens, repeats) = match self.window_size {
None => (tokens, 1), None => (tokens, 1),
Some(window_size) => { Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size; let repeats = (tokens + window_size - 1) / window_size;
@ -94,33 +150,34 @@ async fn block_allocator_task(
} }
}; };
// Pad to a multiple of block size // Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size; let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats) (required_blocks, repeats)
}; };
let tokens = tokens as usize; let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 { if required_blocks > self.free_blocks.len() as u32 {
None None
} else { } else {
let blocks = let blocks = self
free_blocks.split_off(free_blocks.len() - required_blocks as usize); .free_blocks
let mut slots = Vec::with_capacity( .split_off(self.free_blocks.len() - required_blocks as usize);
(required_blocks * block_size * repeats as u32) as usize, let mut slots =
); Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() { 'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) { for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s); slots.push(s);
if slots.len() == tokens { if slots.len() == tokens {
break 'slots; break 'slots;
} }
} }
} }
Some((blocks, slots)) Some((blocks, slots, 0))
};
response_sender.send(allocation).unwrap();
} }
} }
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
} }
} }
@ -128,9 +185,11 @@ async fn block_allocator_task(
enum BlockAllocatorCommand { enum BlockAllocatorCommand {
Free { Free {
blocks: Vec<u32>, blocks: Vec<u32>,
allocation_id: u64,
}, },
Allocate { Allocate {
tokens: u32, tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>, prefill_tokens: Option<Arc<Vec<u32>>>,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>, u64)>>,
}, },
} }

View File

@ -298,7 +298,10 @@ impl State {
+ self.speculate + self.speculate
- 1; - 1;
match block_allocator.allocate(tokens).await { match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
None => { None => {
// Entry is over budget // Entry is over budget
// Add it back to the front // Add it back to the front

View File

@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use std::iter; use std::iter;
use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -121,7 +122,7 @@ impl Validation {
inputs: String, inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel // Create response channel
@ -156,8 +157,10 @@ impl Validation {
)); ));
} }
let input_ids = encoding.get_ids()[..input_length].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64); metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens)) Ok((inputs, Some(input_ids), input_length, max_new_tokens))
} }
// Return inputs without validation // Return inputs without validation
else { else {
@ -180,7 +183,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
} }
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
} }
} }
@ -314,7 +322,7 @@ impl Validation {
.unwrap_or(Ok(None))?; .unwrap_or(Ok(None))?;
// Validate inputs // Validate inputs
let (inputs, input_length, max_new_tokens) = self let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(request.inputs, truncate, max_new_tokens)
.await?; .await?;
@ -391,6 +399,7 @@ impl Validation {
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs, inputs,
input_ids: input_ids.map(Arc::new),
decoder_input_details, decoder_input_details,
input_length: input_length as u32, input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -707,6 +716,7 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ValidGenerateRequest { pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>, pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub decoder_input_details: bool, pub decoder_input_details: bool,