From 5d482d4da24f229b7e335bff98eeb92384b1df56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 1 Aug 2024 13:41:07 +0000 Subject: [PATCH] Port over block allocator interface (with token ids) --- backends/v3/src/block_allocator.rs | 151 ++++++++++++++++++++--------- backends/v3/src/queue.rs | 5 +- router/src/validation.rs | 18 +++- 3 files changed, 123 insertions(+), 51 deletions(-) diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 7467fd85..3a3d53bb 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,16 +1,18 @@ -use std::cmp::min; +use std::{cmp::min, sync::Arc}; use tokio::sync::{mpsc, oneshot}; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { pub blocks: Vec, pub slots: Vec, + pub allocation_id: u64, block_allocator: BlockAllocator, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + self.block_allocator + .free(self.blocks.clone(), self.allocation_id) } } @@ -42,11 +44,16 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); @@ -54,16 +61,20 @@ impl BlockAllocator { response_receiver .await .unwrap() - .map(|(blocks, slots)| BlockAllocation { + .map(|(blocks, slots, allocation_id)| BlockAllocation { blocks, slots, + allocation_id, block_allocator: self.clone(), }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -74,63 +85,111 @@ async fn block_allocator_task( window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator = SimpleAllocator::new(blocks, block_size, window_size); while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice()); + response_sender + .send(allocator.allocate(tokens, prefill_tokens_slice)) + .unwrap(); } } } } +pub trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option<&[u32]>, + ) -> Option<(Vec, Vec, u64)>; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} + +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option<&[u32]>, + ) -> Option<(Vec, Vec, u64)> { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some((blocks, slots, 0)) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} + #[derive(Debug)] enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender, Vec, u64)>>, }, } diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 9427bd60..235d02b2 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -298,7 +298,10 @@ impl State { + self.speculate - 1; - match block_allocator.allocate(tokens).await { + match block_allocator + .allocate(tokens, entry.request.input_ids.clone()) + .await + { None => { // Entry is over budget // Add it back to the front diff --git a/router/src/validation.rs b/router/src/validation.rs index 3d1a4103..72b6431f 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -11,6 +11,7 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; +use std::sync::Arc; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -121,7 +122,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -156,8 +157,10 @@ impl Validation { )); } + let input_ids = encoding.get_ids()[..input_length].to_owned(); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); - Ok((inputs, input_length, max_new_tokens)) + Ok((inputs, Some(input_ids), input_length, max_new_tokens)) } // Return inputs without validation else { @@ -180,7 +183,12 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs)], + None, + input_length, + max_new_tokens, + )) } } @@ -314,7 +322,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let (inputs, input_length, max_new_tokens) = self + let (inputs, input_ids, input_length, max_new_tokens) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -391,6 +399,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + input_ids: input_ids.map(Arc::new), decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -707,6 +716,7 @@ pub struct ValidStoppingParameters { #[derive(Debug, Clone)] pub struct ValidGenerateRequest { pub inputs: Vec, + pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool,