From 7c046c9190a36938a626cd7b1ab08f5c752ca0ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 15 Jul 2024 14:14:10 +0000 Subject: [PATCH] First step towards cleaning up Breaks tests, but I want to shuffle around data structures so that we can just pass block ids to free. --- router/src/infer/v3/block_allocator.rs | 35 +++++++++++++------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index b9a51668..0c3cf288 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -174,6 +174,7 @@ enum BlockAllocatorCommand { pub struct BlockAllocationWithCache { pub blocks: Vec, pub slots: Vec, + pub prefix_hashes: Vec, } #[derive(Debug)] @@ -225,7 +226,7 @@ impl PrefixCacheAllocator { } } - fn alloc( + pub fn alloc( &mut self, n_tokens: usize, prefill_tokens: &[u32], @@ -287,6 +288,7 @@ impl PrefixCacheAllocator { Some(BlockAllocationWithCache { blocks: prefix_cache_blocks, slots, + prefix_hashes, }) } @@ -343,17 +345,10 @@ impl PrefixCacheAllocator { ) } - fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) { - let mut hasher = DefaultHasher::new(); + pub fn free(&mut self, blocks: &[u32], prefix_hashes: &[u64]) { let mut predecessor = None; - for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) { - if prefill_chunk.len() < self.block_size { - break; - } - - prefill_chunk.hash(&mut hasher); - let prefix_hash = hasher.finish(); + for (&prefix_hash, &block_id) in prefix_hashes.iter().zip(blocks.iter()) { match self.cache_blocks.entry(prefix_hash) { Entry::Occupied(mut entry) => { let value = entry.get_mut(); @@ -362,7 +357,7 @@ impl PrefixCacheAllocator { } Entry::Vacant(entry) => { entry.insert(PrefixBlockState { - block_id: *block_id, + block_id, last_accessed: self.time, predecessor, ref_count: 0, @@ -379,8 +374,7 @@ impl PrefixCacheAllocator { self.time += 1; - let n_prefill_blocks = (prefill_tokens.len() + self.block_size - 1) / self.block_size; - for block in &blocks[n_prefill_blocks..] { + for block in &blocks[prefix_hashes.len()..] { self.free_blocks.push(*block); } } @@ -400,7 +394,8 @@ mod tests { allocation, Some(BlockAllocationWithCache { blocks: vec![1, 2], - slots: (4..12).collect() + slots: (4..12).collect(), + prefix_hashes: Vec::new(), }) ); cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]); @@ -410,7 +405,8 @@ mod tests { allocation, Some(BlockAllocationWithCache { blocks: vec![1, 2], - slots: (4..12).collect() + slots: (4..12).collect(), + prefix_hashes: Vec::new(), }) ); } @@ -423,7 +419,8 @@ mod tests { allocation1, Some(BlockAllocationWithCache { blocks: vec![2, 3], - slots: (4..8).collect() + slots: (4..8).collect(), + prefix_hashes: Vec::new(), }) ); @@ -432,7 +429,8 @@ mod tests { allocation2, Some(BlockAllocationWithCache { blocks: vec![1], - slots: (2..4).collect() + slots: (2..4).collect(), + prefix_hashes: Vec::new(), }) ); @@ -445,7 +443,8 @@ mod tests { allocation3, Some(BlockAllocationWithCache { blocks: vec![3, 2], - slots: vec![6, 7, 4, 5] + slots: vec![6, 7, 4, 5], + prefix_hashes: Vec::new(), }) ); }