From dd2d6cfe401863d0580fa00915f9f8361a8c49a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Tue, 16 Jul 2024 11:40:35 +0000
Subject: [PATCH] Proper support for two allocations with overlapping prefixes

---
 router/src/infer/v3/block_allocator.rs | 83 ++++++++++++++++++++------
 1 file changed, 64 insertions(+), 19 deletions(-)

diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index f9bc6764..834e5934 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -174,7 +174,13 @@ enum BlockAllocatorCommand {
 pub struct BlockAllocationWithCache {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
-    pub prefix_hashes: Vec<u64>,
+    pub allocation: u64,
+}
+
+#[derive(Debug)]
+struct Allocation {
+    cache_prefixes: Vec<u64>,
+    new_prefixes: Vec<u64>,
 }
 
 #[derive(Debug)]
@@ -192,6 +198,8 @@ struct PrefixBlockState {
 
 #[derive(Debug)]
 pub struct PrefixCacheAllocator {
+    allocations: HashMap<u64, Allocation>,
+
     /// Size of a paged attention block.
     block_size: usize,
 
@@ -218,6 +226,7 @@ impl PrefixCacheAllocator {
         }
 
         PrefixCacheAllocator {
+            allocations: HashMap::new(),
             block_size,
             cache_blocks: HashMap::new(),
             free_blocks: (1..n_blocks as u32).collect(),
@@ -234,6 +243,7 @@ impl PrefixCacheAllocator {
         let mut hasher = DefaultHasher::new();
         let mut prefix_cache_blocks = Vec::new();
 
+        // Find hashes for all block_sized prefill chunks.
         let mut prefix_hashes = Vec::new();
         for prefill_chunk in prefill_tokens.chunks(self.block_size) {
             if prefill_chunk.len() < self.block_size {
@@ -259,14 +269,17 @@ impl PrefixCacheAllocator {
             n_from_cache += 1;
         }
 
+        let new_prefixes = prefix_hashes.split_off(n_from_cache);
+        let cache_prefixes = prefix_hashes;
+
         // Get tokens for the remaining prefill and decode.
         let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
             Some(blocks) => blocks,
             None => {
                 // If the allocation fails, we have relinquish our use of the
                 // prefix cache blocks. Maybe we can do this using `Drop`?
-                for prefix_hash in &prefix_hashes[..n_from_cache] {
-                    self.decref_prefix(*prefix_hash);
+                for prefix_hash in cache_prefixes {
+                    self.decref_prefix(prefix_hash);
                 }
 
                 return None;
@@ -287,10 +300,19 @@ impl PrefixCacheAllocator {
             }
         }
 
+        let allocation = Allocation {
+            cache_prefixes,
+            new_prefixes,
+        };
+
+        let allocation_id = self.time;
+        self.time += 1;
+        self.allocations.insert(allocation_id, allocation);
+
         Some(BlockAllocationWithCache {
             blocks: prefix_cache_blocks,
             slots,
-            prefix_hashes,
+            allocation: allocation_id,
         })
     }
 
@@ -347,15 +369,40 @@ impl PrefixCacheAllocator {
         )
     }
 
-    pub fn free(&mut self, blocks: &[u32], prefix_hashes: &[u64]) {
+    pub fn free(&mut self, blocks: &[u32], allocation: u64) {
+        let allocation = match self.allocations.remove(&allocation) {
+            Some(allocation) => allocation,
+            None => unreachable!("Tried to free an unknown allocation."),
+        };
+
         let mut predecessor = None;
 
-        for (&prefix_hash, &block_id) in prefix_hashes.iter().zip(blocks.iter()) {
+        // Bookkeeping for prefill blocks that were retrieved from the cache.
+        for &prefix_hash in &allocation.cache_prefixes {
+            let state = match self.cache_blocks.get_mut(&prefix_hash) {
+                Some(state) => state,
+                None => unreachable!("Tried to free an unknown prefix block."),
+            };
+
+            state.last_accessed = self.time;
+            self.decref_prefix(prefix_hash);
+            predecessor = Some(prefix_hash);
+        }
+
+        // Bookkeeping for prefill blocks that were new.
+        for (&block_id, &prefix_hash) in blocks[allocation.cache_prefixes.len()..]
+            .iter()
+            .zip(&allocation.new_prefixes)
+        {
+            // We can't simply cache the block. There may have been a concurrent
+            // allocation for the same prefix.
             match self.cache_blocks.entry(prefix_hash) {
                 Entry::Occupied(mut entry) => {
-                    let value = entry.get_mut();
-                    value.last_accessed = self.time;
-                    self.decref_prefix(prefix_hash);
+                    // A concurrent request added the prefix to the cache. We'll
+                    // only update the last accessed time.
+                    let state = entry.get_mut();
+                    // TODO: also update the time in the LRU set.
+                    state.last_accessed = self.time;
                 }
                 Entry::Vacant(entry) => {
                     entry.insert(PrefixBlockState {
@@ -364,19 +411,17 @@ impl PrefixCacheAllocator {
                         predecessor,
                         ref_count: 0,
                     });
+
                     if let Some(predecessor) = predecessor {
                         self.incref_prefix(predecessor);
                     }
                     self.leaves.insert((self.time, prefix_hash));
                 }
-            };
-
+            }
             predecessor = Some(prefix_hash);
         }
 
-        self.time += 1;
-
-        for block in &blocks[prefix_hashes.len()..] {
+        for block in &blocks[allocation.cache_prefixes.len() + allocation.new_prefixes.len()..] {
             self.free_blocks.push(*block);
         }
     }
@@ -394,7 +439,7 @@ mod tests {
         let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
         assert_eq!(allocation.blocks, vec![1, 2]);
         assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
-        cache.free(&allocation.blocks, &allocation.prefix_hashes);
+        cache.free(&allocation.blocks, allocation.allocation);
 
         let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
         assert_eq!(allocation.blocks, vec![1, 2]);
@@ -412,8 +457,8 @@ mod tests {
         assert_eq!(allocation2.blocks, vec![1]);
         assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
 
-        cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
-        cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
+        cache.free(&allocation1.blocks, allocation1.allocation);
+        cache.free(&allocation2.blocks, allocation2.allocation);
 
         // We should get the blocks of the first allocation, since they are more recent.
         let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
@@ -427,7 +472,7 @@ mod tests {
         let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
        let allocation2 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
 
-        cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
-        cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
+        cache.free(&allocation2.blocks, allocation2.allocation);
+        cache.free(&allocation1.blocks, allocation1.allocation);
     }
 }
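
For readers following the change: the patch keys every allocation by an id and remembers, per allocation, which prefix hashes came from the cache and which were newly produced, so that two in-flight allocations sharing a prefix can be freed independently. Below is a minimal, self-contained sketch of that bookkeeping idea only; it is not part of the patch, and PrefixBookkeeping, ref_counts, and the simplified alloc/free signatures are illustrative stand-ins for the real allocator, which additionally manages blocks, slots, predecessors, and an LRU leaf set.

use std::collections::hash_map::Entry;
use std::collections::HashMap;

// Simplified stand-in for the per-allocation bookkeeping introduced by the
// patch: which prefix hashes were served from the cache and which are new.
#[derive(Debug)]
struct Allocation {
    cache_prefixes: Vec<u64>,
    new_prefixes: Vec<u64>,
}

#[derive(Debug, Default)]
struct PrefixBookkeeping {
    // Live allocations, keyed by the id handed back to the caller.
    allocations: HashMap<u64, Allocation>,
    // Reference count per cached prefix hash (stands in for PrefixBlockState).
    ref_counts: HashMap<u64, usize>,
    // Monotonic counter used as the allocation id.
    time: u64,
}

impl PrefixBookkeeping {
    // Record an allocation and return the id the caller later passes to `free`.
    fn alloc(&mut self, cache_prefixes: Vec<u64>, new_prefixes: Vec<u64>) -> u64 {
        // Prefixes taken from the cache are pinned while the allocation lives.
        for &prefix in &cache_prefixes {
            *self.ref_counts.entry(prefix).or_insert(0) += 1;
        }
        let id = self.time;
        self.time += 1;
        self.allocations.insert(
            id,
            Allocation {
                cache_prefixes,
                new_prefixes,
            },
        );
        id
    }

    // Release an allocation: unpin cached prefixes and publish new ones,
    // unless a concurrent allocation already published the same prefix.
    fn free(&mut self, id: u64) {
        let allocation = self.allocations.remove(&id).expect("unknown allocation");
        for prefix in allocation.cache_prefixes {
            let count = self.ref_counts.get_mut(&prefix).expect("unknown prefix");
            *count = count.saturating_sub(1);
        }
        for prefix in allocation.new_prefixes {
            match self.ref_counts.entry(prefix) {
                // A concurrent allocation already cached this prefix: keep its entry.
                Entry::Occupied(_) => {}
                // First time this prefix is published; it starts unreferenced.
                Entry::Vacant(entry) => {
                    entry.insert(0);
                }
            }
        }
    }
}

fn main() {
    let mut bookkeeping = PrefixBookkeeping::default();
    // Two overlapping in-flight allocations both produce prefix hash 42 as "new".
    let a = bookkeeping.alloc(vec![], vec![42]);
    let b = bookkeeping.alloc(vec![], vec![42]);
    bookkeeping.free(a);
    // Freeing the second allocation must not double-publish the prefix.
    bookkeeping.free(b);
    assert_eq!(bookkeeping.ref_counts.get(&42), Some(&0));
    println!("ok");
}

The final assertions mirror the last test in the patch: two allocations with the same prefill can be freed in either order and must leave a single, consistent cache entry behind.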