diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 8049b881..f989738e 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use std::{ borrow::BorrowMut, cmp::min, - collections::{hash_map::Entry, HashMap}, + collections::{hash_map::Entry, HashMap, HashSet}, hash::{DefaultHasher, Hash, Hasher}, }; use tokio::sync::{mpsc, oneshot}; @@ -160,6 +160,7 @@ pub struct PrefixCache { block_size: usize, cache_partial: bool, free_blocks: Vec, + leaves: HashSet, cache_blocks: HashMap, // Avoid a system call, use a counter for time. @@ -173,18 +174,18 @@ impl PrefixCache { cache_blocks: HashMap::new(), cache_partial, free_blocks: (1..n_blocks as u32).collect(), + leaves: HashSet::new(), time: 0, } } fn alloc( &mut self, - mut n_tokens: usize, + n_tokens: usize, prefill_tokens: &[u32], ) -> Option { let mut hasher = DefaultHasher::new(); let mut tokens_from_cache = 0; - let mut cache_blocks_hashes = Vec::new(); let mut prefix_cache_blocks = Vec::new(); let mut prefix_hashes = Vec::new(); for prefill_chunk in prefill_tokens.chunks(self.block_size) { @@ -194,15 +195,14 @@ impl PrefixCache { prefill_chunk.hash(&mut hasher); - match self.cache_blocks.get_mut(&hasher.finish()) { - Some(state) => { - // Ensure that we don't evict the prefix blocks. - state.ref_count += 1; - prefix_cache_blocks.push(state.block_id); - prefix_hashes.push(hasher.finish()); - } + let block_id = match self.cache_blocks.get(&hasher.finish()) { + Some(state) => state.block_id, None => break, - } + }; + + self.incref_prefix(hasher.finish()); + prefix_hashes.push(hasher.finish()); + prefix_cache_blocks.push(block_id); tokens_from_cache += prefill_chunk.len(); } @@ -218,19 +218,6 @@ impl PrefixCache { } }; - for hash in cache_blocks_hashes { - match self.cache_blocks.get_mut(&hash) { - Some(info) => { - info.ref_count += 1; - prefix_cache_blocks.push(info.block_id); - } - None => unreachable!(), - } - } - - eprintln!("Allocated blocks: {:?}", blocks); - eprintln!("Prefix blocks: {:?}", prefix_cache_blocks); - prefix_cache_blocks.extend(blocks); let mut slots = Vec::with_capacity(n_tokens); @@ -251,6 +238,40 @@ impl PrefixCache { }) } + fn free_prefix(&mut self, prefix_hash: u64) { + let state = self + .cache_blocks + .remove(&prefix_hash) + .expect("Unknown hash"); + + if let Some(predecessor) = state.predecessor { + self.decref_prefix(predecessor); + } + + self.leaves.remove(&prefix_hash); + } + + fn decref_prefix(&mut self, prefix_hash: u64) { + let state = self + .cache_blocks + .get_mut(&prefix_hash) + .expect("Unknown hash"); + assert!(state.ref_count > 0); + state.ref_count -= 1; + if state.ref_count == 0 { + self.leaves.insert(prefix_hash); + } + } + + fn incref_prefix(&mut self, prefix_hash: u64) { + let state = self + .cache_blocks + .get_mut(&prefix_hash) + .expect("Unknown hash"); + state.ref_count += 1; + self.leaves.remove(&prefix_hash); + } + fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option> { let n_blocks = (n_tokens + self.block_size - 1) / self.block_size; let n_blocks_needed = if n_tokens > self.free_blocks.len() { @@ -259,28 +280,24 @@ impl PrefixCache { 0 }; - if n_blocks_needed > 0 { - let removable_blocks = self - .cache_blocks - .iter_mut() - // Block must be unused. - .filter(|(_, state)| state.ref_count == 0) - // Remove most recent block first. - // TODO: we are not yet removing a prefix in reverse order. - .sorted_by_key(|(_, state)| state.last_accessed) - // Find enough candidates. - .take(n_blocks_needed) - .map(|(block_hash, block_state)| (*block_hash, block_state.block_id)) - .collect::>(); + while self.free_blocks.len() < n_blocks_needed { + // We have to free one block at a time, because removing the LRU + // prefix block may make available another prefix block that is + // LRU. + // + // TODO: switch to something like a binary heap to avoid sorting + // the set of leaves over and over again. - if removable_blocks.len() < n_blocks_needed { - return None; - } + let (lru_prefix_hash, lru_block_id) = self + .leaves + .iter() + .map(|prefix_hash| (prefix_hash, &self.cache_blocks[prefix_hash])) + .sorted_by_key(|state| state.1.last_accessed) + .map(|(prefix_hash, state)| (*prefix_hash, state.block_id)) + .next()?; - for (block_hash, block_id) in removable_blocks.into_iter() { - self.free_blocks.push(block_id); - self.cache_blocks.remove(&block_hash); - } + self.free_prefix(lru_prefix_hash); + self.free_blocks.push(lru_block_id); } Some(