diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 54652419..022625dc 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -171,9 +171,6 @@ pub struct PrefixCache { /// on its preceding prefix block. cache_blocks: HashMap, - /// Whether to cache partial blocks. - cache_partial: bool, - /// Blocks that are immediately available for allocation. free_blocks: Vec, @@ -185,11 +182,10 @@ pub struct PrefixCache { } impl PrefixCache { - pub fn new(block_size: usize, n_blocks: usize, cache_partial: bool) -> Self { + pub fn new(block_size: usize, n_blocks: usize) -> Self { PrefixCache { block_size, cache_blocks: HashMap::new(), - cache_partial, free_blocks: (1..n_blocks as u32).collect(), leaves: HashSet::new(), time: 0, @@ -202,11 +198,10 @@ impl PrefixCache { prefill_tokens: &[u32], ) -> Option { let mut hasher = DefaultHasher::new(); - let mut tokens_from_cache = 0; let mut prefix_cache_blocks = Vec::new(); let mut prefix_hashes = Vec::new(); for prefill_chunk in prefill_tokens.chunks(self.block_size) { - if prefill_chunk.len() < self.block_size && !self.cache_partial { + if prefill_chunk.len() < self.block_size { break; } @@ -225,12 +220,11 @@ impl PrefixCache { self.incref_prefix(prefix_hash); prefix_hashes.push(prefix_hash); prefix_cache_blocks.push(block_id); - - tokens_from_cache += prefill_chunk.len(); } // Get tokens for the remaining prefill and decode. - let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) { + let blocks = match self.alloc_or_reclaim(n_tokens - (prefix_hashes.len() * self.block_size)) + { Some(blocks) => blocks, None => { // If the allocation fails, we have relinquish our use of the @@ -247,7 +241,6 @@ impl PrefixCache { let mut slots = Vec::with_capacity(n_tokens); for block_id in prefix_cache_blocks.iter() { - // TODO: fixme: doesn't work with cache_partial yet. for s in (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32) { @@ -301,7 +294,7 @@ impl PrefixCache { fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option> { let n_blocks = (n_tokens + self.block_size - 1) / self.block_size; - let n_blocks_needed = if n_tokens > self.free_blocks.len() { + let n_blocks_needed = if n_blocks > self.free_blocks.len() { n_blocks - self.free_blocks.len() } else { 0 @@ -337,7 +330,7 @@ impl PrefixCache { let mut hasher = DefaultHasher::new(); let mut predecessor = None; for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) { - if prefill_chunk.len() < self.block_size && !self.cache_partial { + if prefill_chunk.len() < self.block_size { break; } @@ -357,6 +350,9 @@ impl PrefixCache { predecessor, ref_count: 0, }); + if let Some(predecessor) = predecessor { + self.incref_prefix(predecessor); + } self.leaves.insert(prefix_hash); } }; @@ -381,7 +377,7 @@ mod tests { #[test] fn test_prefix_cache() { - let mut cache = PrefixCache::new(4, 3, false); + let mut cache = PrefixCache::new(4, 3); let allocation = cache.alloc(8, &[0, 1, 2, 3]); assert_eq!( allocation, @@ -392,8 +388,6 @@ mod tests { ); cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]); - eprintln!("{:?}", cache); - let allocation = cache.alloc(8, &[0, 1, 2, 3]); assert_eq!( allocation, @@ -403,4 +397,39 @@ mod tests { }) ); } + + #[test] + fn test_older_prefixes_are_collected_first() { + let mut cache = PrefixCache::new(2, 4); + let allocation1 = cache.alloc(4, &[0, 1, 2, 3]); + assert_eq!( + allocation1, + Some(BlockAllocationWithCache { + blocks: vec![2, 3], + slots: (4..8).collect() + }) + ); + + let allocation2 = cache.alloc(2, &[4, 5]); + assert_eq!( + allocation2, + Some(BlockAllocationWithCache { + blocks: vec![1], + slots: (2..4).collect() + }) + ); + + cache.free(&allocation1.unwrap().blocks, &[0, 1, 2, 3]); + cache.free(&allocation2.unwrap().blocks, &[4, 5]); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.alloc(4, &[6, 7, 8, 9]); + assert_eq!( + allocation3, + Some(BlockAllocationWithCache { + blocks: vec![3, 2], + slots: vec![6, 7, 4, 5] + }) + ); + } }