From 0e6ff1293a7dcc089aa9f15bdf080e8fd4b1b932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 16 Jul 2024 10:10:10 +0000 Subject: [PATCH] Fixes --- router/src/infer/v3/block_allocator.rs | 86 +++++++++----------------- 1 file changed, 29 insertions(+), 57 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 0c3cf288..65e75be8 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -233,6 +233,7 @@ impl PrefixCacheAllocator { ) -> Option { let mut hasher = DefaultHasher::new(); let mut prefix_cache_blocks = Vec::new(); + let mut prefix_hashes = Vec::new(); for prefill_chunk in prefill_tokens.chunks(self.block_size) { if prefill_chunk.len() < self.block_size { @@ -240,10 +241,12 @@ impl PrefixCacheAllocator { } prefill_chunk.hash(&mut hasher); + prefix_hashes.push(hasher.finish()); + } - let prefix_hash = hasher.finish(); - - let block_id = match self.cache_blocks.get(&prefix_hash) { + let mut n_from_cache = 0; + for prefix_hash in prefix_hashes.iter() { + let block_id = match self.cache_blocks.get(prefix_hash) { Some(state) => state.block_id, None => break, }; @@ -251,20 +254,19 @@ impl PrefixCacheAllocator { // We have to acquire the prefixes blocks, even if the allocation fails // later, otherwise the allocation below could garbage collect the // prefix blocks. - self.incref_prefix(prefix_hash); - prefix_hashes.push(prefix_hash); + self.incref_prefix(*prefix_hash); prefix_cache_blocks.push(block_id); + n_from_cache += 1; } // Get tokens for the remaining prefill and decode. - let blocks = match self.alloc_or_reclaim(n_tokens - (prefix_hashes.len() * self.block_size)) - { + let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) { Some(blocks) => blocks, None => { // If the allocation fails, we have relinquish our use of the // prefix cache blocks. Maybe we can do this using `Drop`? - for prefix_hash in prefix_hashes { - self.decref_prefix(prefix_hash); + for prefix_hash in &prefix_hashes[..n_from_cache] { + self.decref_prefix(*prefix_hash); } return None; @@ -389,63 +391,33 @@ mod tests { #[test] fn test_prefix_cache() { let mut cache = PrefixCacheAllocator::new(4, 3, None); - let allocation = cache.alloc(8, &[0, 1, 2, 3]); - assert_eq!( - allocation, - Some(BlockAllocationWithCache { - blocks: vec![1, 2], - slots: (4..12).collect(), - prefix_hashes: Vec::new(), - }) - ); - cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]); + let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap(); + assert_eq!(allocation.blocks, vec![1, 2]); + assert_eq!(allocation.slots, (4..12).collect::>()); + cache.free(&allocation.blocks, &allocation.prefix_hashes); - let allocation = cache.alloc(8, &[0, 1, 2, 3]); - assert_eq!( - allocation, - Some(BlockAllocationWithCache { - blocks: vec![1, 2], - slots: (4..12).collect(), - prefix_hashes: Vec::new(), - }) - ); + let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap(); + assert_eq!(allocation.blocks, vec![1, 2]); + assert_eq!(allocation.slots, (4..12).collect::>()); } #[test] fn test_older_prefixes_are_collected_first() { let mut cache = PrefixCacheAllocator::new(2, 4, None); - let allocation1 = cache.alloc(4, &[0, 1, 2, 3]); - assert_eq!( - allocation1, - Some(BlockAllocationWithCache { - blocks: vec![2, 3], - slots: (4..8).collect(), - prefix_hashes: Vec::new(), - }) - ); + let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap(); + assert_eq!(allocation1.blocks, vec![2, 3]); + assert_eq!(allocation1.slots, (4..8).collect::>()); - let allocation2 = cache.alloc(2, &[4, 5]); - assert_eq!( - allocation2, - Some(BlockAllocationWithCache { - blocks: vec![1], - slots: (2..4).collect(), - prefix_hashes: Vec::new(), - }) - ); + let allocation2 = cache.alloc(2, &[4, 5]).unwrap(); + assert_eq!(allocation2.blocks, vec![1]); + assert_eq!(allocation2.slots, (2..4).collect::>()); - cache.free(&allocation1.unwrap().blocks, &[0, 1, 2, 3]); - cache.free(&allocation2.unwrap().blocks, &[4, 5]); + cache.free(&allocation1.blocks, &allocation1.prefix_hashes); + cache.free(&allocation2.blocks, &allocation2.prefix_hashes); // We should get the blocks of the first allocation, since they are more recent. - let allocation3 = cache.alloc(4, &[6, 7, 8, 9]); - assert_eq!( - allocation3, - Some(BlockAllocationWithCache { - blocks: vec![3, 2], - slots: vec![6, 7, 4, 5], - prefix_hashes: Vec::new(), - }) - ); + let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 2]); + assert_eq!(allocation3.slots, vec![6, 7, 4, 5]); } }