From 6d0094e5d4865f00cfcc592016aaadbf779e34d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Wed, 10 Jul 2024 14:06:23 +0200
Subject: [PATCH] docs/cleanups

---
 router/src/infer/v3/block_allocator.rs | 54 +++++++++++++++++++-------
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index f989738e..54652419 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -148,20 +148,37 @@ pub(crate) struct BlockAllocationWithCache {
 }
 
 #[derive(Debug)]
-struct BlockState {
+struct PrefixBlockState {
     block_id: u32,
+
+    /// Last prefix block use.
     last_accessed: u64,
+
+    /// Prefix predecessor (parent in the prefix trie).
     predecessor: Option<u64>,
+
     ref_count: usize,
 }
 
 #[derive(Debug)]
 pub struct PrefixCache {
+    /// Size of a paged attention block.
     block_size: usize,
+
+    /// Blocks that cache a prefix with the given hash.
+    ///
+    /// The blocks form a Merkle tree, because a prefix block is dependent
+    /// on its preceding prefix block.
+    cache_blocks: HashMap<u64, PrefixBlockState>,
+
+    /// Whether to cache partial blocks.
     cache_partial: bool,
+
+    /// Blocks that are immediately available for allocation.
     free_blocks: Vec<u32>,
+
+    /// Prefix blocks with a reference count of zero.
     leaves: HashSet<u64>,
-    cache_blocks: HashMap<u64, BlockState>,
 
     // Avoid a system call, use a counter for time.
     time: u64,
@@ -195,24 +212,32 @@ impl PrefixCache {
 
             prefill_chunk.hash(&mut hasher);
 
-            let block_id = match self.cache_blocks.get(&hasher.finish()) {
+            let prefix_hash = hasher.finish();
+
+            let block_id = match self.cache_blocks.get(&prefix_hash) {
                 Some(state) => state.block_id,
                 None => break,
             };
 
-            self.incref_prefix(hasher.finish());
-            prefix_hashes.push(hasher.finish());
+            // We have to acquire the prefix blocks, even if the allocation fails
+            // later, otherwise the allocation below could garbage collect the
+            // prefix blocks.
+            self.incref_prefix(prefix_hash);
+            prefix_hashes.push(prefix_hash);
             prefix_cache_blocks.push(block_id);
 
             tokens_from_cache += prefill_chunk.len();
         }
 
+        // Get tokens for the remaining prefill and decode.
         let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
             Some(blocks) => blocks,
             None => {
-                prefix_hashes.into_iter().for_each(|hash| {
-                    self.cache_blocks.get_mut(&hash).unwrap().ref_count -= 1;
-                });
+                // If the allocation fails, we have to relinquish our use of the
+                // prefix cache blocks. Maybe we can do this using `Drop`?
+                for prefix_hash in prefix_hashes {
+                    self.decref_prefix(prefix_hash);
+                }
 
                 return None;
             }
@@ -222,6 +247,7 @@ impl PrefixCache {
         let mut slots = Vec::with_capacity(n_tokens);
 
         for block_id in prefix_cache_blocks.iter() {
+            // TODO: fixme: doesn't work with cache_partial yet.
             for s in
                 (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
             {
@@ -244,6 +270,7 @@ impl PrefixCache {
             .remove(&prefix_hash)
             .expect("Unknown hash");
 
+        // Parent has one user less.
         if let Some(predecessor) = state.predecessor {
             self.decref_prefix(predecessor);
         }
@@ -315,25 +342,26 @@ impl PrefixCache {
             }
 
             prefill_chunk.hash(&mut hasher);
+            let prefix_hash = hasher.finish();
 
-            match self.cache_blocks.entry(hasher.finish()) {
+            match self.cache_blocks.entry(prefix_hash) {
                 Entry::Occupied(mut entry) => {
                     let value = entry.get_mut();
                     value.last_accessed = self.time;
-                    assert!(value.ref_count > 0);
-                    value.ref_count -= 1;
+                    self.decref_prefix(prefix_hash);
                 }
                 Entry::Vacant(entry) => {
-                    entry.insert(BlockState {
+                    entry.insert(PrefixBlockState {
                         block_id: *block_id,
                         last_accessed: self.time,
                         predecessor,
                         ref_count: 0,
                     });
+                    self.leaves.insert(prefix_hash);
                 }
             };
 
-            predecessor = Some(hasher.finish());
+            predecessor = Some(prefix_hash);
         }
 
         self.time += 1;
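
The comments added above amount to a simple acquire/release discipline: matched prefix blocks are reference-counted before the allocation for the remaining tokens is attempted, and released again if that allocation fails, so an in-flight allocation can never reclaim blocks it still depends on. The sketch below illustrates that discipline in isolation; it is a minimal standalone example, and `ToyPrefixCache`, `acquire`, and `release` are illustrative names, not identifiers from the patched module.

```rust
use std::collections::{HashMap, HashSet};

/// Minimal sketch of the reference-counting scheme used for prefix blocks.
#[derive(Debug, Default)]
struct ToyPrefixCache {
    /// Reference count per prefix hash (hypothetical stand-in for the
    /// per-block `ref_count` in the patch).
    ref_counts: HashMap<u64, usize>,
    /// Prefix hashes with a reference count of zero, i.e. eligible for reclaim.
    leaves: HashSet<u64>,
}

impl ToyPrefixCache {
    /// Take a reference on a cached prefix block so it cannot be reclaimed
    /// while an allocation that depends on it is in flight.
    fn acquire(&mut self, prefix_hash: u64) {
        let count = self.ref_counts.entry(prefix_hash).or_insert(0);
        *count += 1;
        // A block that is in use is no longer a reclaimable leaf.
        self.leaves.remove(&prefix_hash);
    }

    /// Drop a reference; once the count reaches zero the block becomes a
    /// leaf again and may be reclaimed by a later allocation.
    fn release(&mut self, prefix_hash: u64) {
        let count = self
            .ref_counts
            .get_mut(&prefix_hash)
            .expect("release of unknown prefix hash");
        assert!(*count > 0);
        *count -= 1;
        if *count == 0 {
            self.leaves.insert(prefix_hash);
        }
    }
}

fn main() {
    let mut cache = ToyPrefixCache::default();
    let prefix_hashes = [1u64, 2, 3];

    // Acquire all matched prefix blocks up front, as the patch does, so the
    // follow-up allocation cannot garbage collect them.
    for &hash in &prefix_hashes {
        cache.acquire(hash);
    }

    // Pretend the allocation for the remaining tokens failed: relinquish the
    // prefix blocks again, mirroring the `decref_prefix` loop in the patch.
    let allocation_succeeded = false;
    if !allocation_succeeded {
        for &hash in &prefix_hashes {
            cache.release(hash);
        }
    }

    // All three blocks are reclaimable leaves again.
    assert_eq!(cache.leaves.len(), 3);
    println!("reclaimable leaves: {:?}", cache.leaves);
}
```

Keeping a separate set of zero-reference blocks makes reclamation cheap: only leaves are candidates for eviction, which mirrors the `leaves: HashSet<u64>` field and the `self.leaves.insert(prefix_hash)` call introduced by the patch.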