From 9da64a7b166f9b2393e8ad0a412686215ede9a0c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Tue, 9 Jul 2024 16:37:47 +0200
Subject: [PATCH] Basic test passes

---
 router/src/infer/v3/block_allocator.rs | 42 ++++++++++++++++++++------
 1 file changed, 32 insertions(+), 10 deletions(-)

diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 9e27b8a6..8049b881 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -151,9 +151,11 @@ pub(crate) struct BlockAllocationWithCache {
 struct BlockState {
     block_id: u32,
     last_accessed: u64,
+    predecessor: Option<u64>,
     ref_count: usize,
 }
 
+#[derive(Debug)]
 pub struct PrefixCache {
     block_size: usize,
     cache_partial: bool,
@@ -180,10 +182,11 @@ impl PrefixCache {
         mut n_tokens: usize,
         prefill_tokens: &[u32],
     ) -> Option<BlockAllocationWithCache> {
-        // First try to lookup prefix.
        let mut hasher = DefaultHasher::new();
         let mut tokens_from_cache = 0;
         let mut cache_blocks_hashes = Vec::new();
+        let mut prefix_cache_blocks = Vec::new();
+        let mut prefix_hashes = Vec::new();
         for prefill_chunk in prefill_tokens.chunks(self.block_size) {
             if prefill_chunk.len() < self.block_size && !self.cache_partial {
                 break;
             }
@@ -193,22 +196,28 @@ impl PrefixCache {
             match self.cache_blocks.get_mut(&hasher.finish()) {
                 Some(state) => {
+                    // Ensure that we don't evict the prefix blocks.
                     state.ref_count += 1;
+                    prefix_cache_blocks.push(state.block_id);
+                    prefix_hashes.push(hasher.finish());
                 }
-                None => todo!(),
-            }
-
-            if !self.cache_blocks.contains_key(&hasher.finish()) {
-                break;
+                None => break,
             }
 
             tokens_from_cache += prefill_chunk.len();
-            cache_blocks_hashes.push(hasher.finish());
         }
 
-        let blocks = self.alloc_or_reclaim(n_tokens - tokens_from_cache)?;
+        let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
+            Some(blocks) => blocks,
+            None => {
+                prefix_hashes.into_iter().for_each(|hash| {
+                    self.cache_blocks.get_mut(&hash).unwrap().ref_count -= 1;
+                });
+
+                return None;
+            }
+        };
 
-        let mut prefix_cache_blocks = Vec::new();
         for hash in cache_blocks_hashes {
             match self.cache_blocks.get_mut(&hash) {
                 Some(info) => {
@@ -219,6 +228,9 @@ impl PrefixCache {
             }
         }
 
+        eprintln!("Allocated blocks: {:?}", blocks);
+        eprintln!("Prefix blocks: {:?}", prefix_cache_blocks);
+
         prefix_cache_blocks.extend(blocks);
 
         let mut slots = Vec::with_capacity(n_tokens);
@@ -279,7 +291,7 @@ impl PrefixCache {
     fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) {
         let mut hasher = DefaultHasher::new();
-
+        let mut predecessor = None;
         for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
             if prefill_chunk.len() < self.block_size && !self.cache_partial {
                 break;
             }
@@ -298,13 +310,21 @@ impl PrefixCache {
                     entry.insert(BlockState {
                         block_id: *block_id,
                         last_accessed: self.time,
+                        predecessor,
                         ref_count: 0,
                     });
                 }
             };
+
+            predecessor = Some(hasher.finish());
         }
 
         self.time += 1;
+
+        let n_prefill_blocks = (prefill_tokens.len() + self.block_size - 1) / self.block_size;
+        for block in &blocks[n_prefill_blocks..] {
+            self.free_blocks.push(*block);
+        }
     }
 }
@@ -327,6 +347,8 @@ mod tests {
         );
 
         cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
+        eprintln!("{:?}", cache);
+
         let allocation = cache.alloc(8, &[0, 1, 2, 3]);
         assert_eq!(
             allocation,