diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 0d6c7cfa..a084a505 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,44 +1,55 @@ -use std::cmp::min; use std::sync::{Arc, Mutex}; use thiserror::Error; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { - pub blocks: Vec, - pub slots: Vec, - prompt_tokens: u32, - decode_tokens: u32, + allocated_blocks: Vec, + allocated_slots: Vec, + required_blocks: usize, + required_slots: usize, block_allocator: BlockAllocator, } impl BlockAllocation { pub(crate) fn len(&self) -> usize { - self.slots.len() + self.allocated_slots.len() + } + + pub(crate) fn blocks(&self) -> &[u32] { + &self.allocated_blocks + } + + pub(crate) fn slots(&self) -> &[u32] { + &self.allocated_slots } pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { let (block, slots) = self.block_allocator.allocate_block()?; + // Add block and slots to current allocation + self.allocated_blocks.push(block); + self.allocated_slots.extend(slots); - match self.block_allocator.window_size { - None => { - self.blocks.push(block); - self.slots.extend(slots); - } - Some(window_size) => { - if self.len() as u32 > window_size { - let total_tokens = self.prompt_tokens + self.decode_tokens; - - let repeats = (total_tokens + window_size - 1) / window_size; - } + if let Some(window_size) = self.block_allocator.window_size { + // if we have more slots than the window size, + // we will never need to re-allocate and we can just repeat the blocks/slots + let window_size = window_size as usize; + if self.len() > window_size { + let repeats = (self.required_slots + window_size - 1) / window_size; + self.allocated_blocks = self.allocated_blocks.repeat(repeats); + self.allocated_blocks.truncate(self.required_blocks); + self.allocated_slots = self.allocated_slots.repeat(repeats); + self.allocated_slots.truncate(self.required_slots); } } + Ok(()) } } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + let allocated_blocks = std::mem::take(&mut self.allocated_blocks); + self.block_allocator.free(allocated_blocks) } } @@ -82,85 +93,76 @@ impl BlockAllocator { /// For decode tokens, we allocate block by block /// /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots - fn allocate( - &self, - prompt_tokens: u32, - decode_tokens: u32, - ) -> Result<(Vec, Vec), AllocationError> { - // let decode_tokens = min(decode_tokens, self.block_size); - // let tokens = prompt_tokens + decode_tokens; - - let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; - // prompt blocks + a single block for decode - let required_blocks = required_prompt_blocks + 1; - - let (required_blocks, repeats) = match self.window_size { - // Nothing to do - None => (required_blocks, 1), - Some(window_size) => { - // Number of blocks needed for this window size - let window_size_required_blocks = (window_size + self.block_size - 1) / self.block_size; - // Number of times we will need to repeat blocks to cover the required allocation - let repeats = (required_blocks + window_size_required_blocks -1) / window_size_required_blocks; - let required_blocks = min(required_blocks, window_size_required_blocks); - - (required_blocks, repeats) - } - }; - - - /// if prompt + decode < window size => do nothing - /// if prompt + decode > window size => do normal until we reach window size then - - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match self.window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + self.block_size - 1) / self.block_size; - (required_blocks, repeats) - }; - - let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); - - if required_blocks > free_blocks.len() as u32 { - Err(AllocationError::NotEnoughPages) - } else { - let n_free_blocks = free_blocks.len(); - let blocks = - free_blocks.split_off(n_free_blocks - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * self.block_size * repeats as u32) as usize, - ); - - for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - slots.push(s); - } - } - Ok((blocks, slots)) - } - } - pub(crate) fn block_allocation( &self, prompt_tokens: u32, decode_tokens: u32, ) -> Result { - self.allocate_inner(prompt_tokens, decode_tokens) - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - prompt_tokens, - decode_tokens, + let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; + // prompt blocks + a single block for decode + let required_blocks = required_prompt_blocks + 1; + + let (clipped_required_blocks, repeats) = match self.window_size { + // Nothing to do + None => (required_blocks, 1), + Some(window_size) => { + // Number of blocks for this window size + let window_size_blocks = (window_size + self.block_size - 1) / self.block_size; + + if required_blocks > window_size_blocks { + // Number of times we will need to repeat blocks to cover the required allocation + let repeats = (required_blocks + window_size_blocks - 1) / window_size_blocks; + (window_size_blocks, repeats) + } else { + (required_blocks, 1) + } + } + }; + + let repeats = repeats as usize; + let required_blocks = required_blocks as usize; + let clipped_required_blocks = clipped_required_blocks as usize; + + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + + if clipped_required_blocks > free_blocks.len() { + Err(AllocationError::NotEnoughPages) + } else { + let n_free_blocks = free_blocks.len(); + let allocated_blocks = + free_blocks.split_off(n_free_blocks - clipped_required_blocks); + + let allocated_blocks = if repeats != 1 { + let mut allocated_blocks = allocated_blocks.repeat(repeats); + allocated_blocks.truncate(required_blocks); + allocated_blocks + } else { + allocated_blocks + }; + + let mut allocated_slots = Vec::with_capacity( + allocated_blocks.len() * self.block_size as usize * repeats, + ); + + let required_slots = (prompt_tokens + decode_tokens) as usize; + + 'slots: for block_id in allocated_blocks.iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + allocated_slots.push(s); + if allocated_slots.len() > required_slots { + break 'slots; + } + } + } + + Ok(BlockAllocation { + allocated_blocks, + allocated_slots, + required_blocks, + required_slots, block_allocator: self.clone(), }) + } } pub(crate) fn free(&self, blocks: Vec) { diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 14e67fff..db09f9b4 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -283,7 +283,7 @@ impl State { let decode_tokens = entry.request.stopping_parameters.max_new_tokens + self.speculate; match block_allocator - .allocate(entry.request.input_length, decode_tokens) + .block_allocation(entry.request.input_length, decode_tokens) { Err(_) => { // Entry is over budget @@ -294,7 +294,7 @@ impl State { } Ok(block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + max_blocks = max(max_blocks, block_allocation.blocks().len() as u32); Some(block_allocation) } } @@ -313,8 +313,8 @@ impl State { let (blocks, slots) = match &block_allocation { None => (Vec::new(), Vec::new()), Some(block_allocation) => ( - block_allocation.blocks.clone(), - block_allocation.slots.clone(), + block_allocation.blocks().to_vec(), + block_allocation.slots().to_vec(), ), }; diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 0913f495..5fb9e11d 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -347,7 +347,7 @@ async fn filter_batch( let (blocks, slots) = entry .block_allocation .as_ref() - .map(|alloc| (alloc.blocks.clone(), alloc.slots.clone())) + .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec())) .unwrap_or((Vec::new(), Vec::new())); KeptRequest {