diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 34275df0..f6e45849 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -2,7 +2,7 @@ use itertools::Itertools;
 use std::{
     borrow::BorrowMut,
     cmp::{min, Reverse},
-    collections::{hash_map::Entry, BinaryHeap, HashMap, HashSet},
+    collections::{hash_map::Entry, BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet},
     hash::{DefaultHasher, Hash, Hasher},
 };
 use tokio::sync::{mpsc, oneshot};
@@ -142,7 +142,7 @@ enum BlockAllocatorCommand {
 }
 
 #[derive(Debug, Clone, Eq, PartialEq)]
-pub(crate) struct BlockAllocationWithCache {
+pub struct BlockAllocationWithCache {
     pub blocks: Vec<u32>,
     pub slots: Vec<u32>,
 }
@@ -174,8 +174,8 @@ pub struct PrefixCache {
     /// Blocks that are immediately available for allocation.
     free_blocks: Vec<u32>,
 
-    /// Prefix blocks with a reference count of zero.
-    leaves: HashSet<u64>,
+    /// Prefix blocks with a reference count of zero, by staleness.
+    leaves: BTreeSet<(u64, u64)>,
 
     // Avoid a system call, use a counter for time.
     time: u64,
@@ -187,7 +187,7 @@ impl PrefixCache {
             block_size,
             cache_blocks: HashMap::new(),
             free_blocks: (1..n_blocks as u32).collect(),
-            leaves: HashSet::new(),
+            leaves: BTreeSet::new(),
             time: 0,
         }
     }
@@ -268,7 +268,7 @@ impl PrefixCache {
             self.decref_prefix(predecessor);
         }
 
-        self.leaves.remove(&prefix_hash);
+        self.leaves.remove(&(state.last_accessed, prefix_hash));
 
         self.free_blocks.push(state.block_id);
     }
@@ -280,7 +280,7 @@ impl PrefixCache {
         assert!(state.ref_count > 0);
         state.ref_count -= 1;
         if state.ref_count == 0 {
-            self.leaves.insert(prefix_hash);
+            self.leaves.insert((state.last_accessed, prefix_hash));
         }
     }
 
@@ -290,39 +290,18 @@ impl PrefixCache {
             .get_mut(&prefix_hash)
             .expect("Unknown hash");
         state.ref_count += 1;
-        self.leaves.remove(&prefix_hash);
+        self.leaves.remove(&(state.last_accessed, prefix_hash));
     }
 
     fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
         let n_blocks_needed = (n_tokens + self.block_size - 1) / self.block_size;
 
-        let mut reclaimable_blocks = self
-            .leaves
-            .iter()
-            .map(|prefix_hash| {
-                let state = &self.cache_blocks[prefix_hash];
-                Reverse((state.last_accessed, *prefix_hash, state.predecessor))
-            })
-            .collect::<BinaryHeap<_>>();
-
         while self.free_blocks.len() < n_blocks_needed {
             // We have to free one block at a time because removing the LRU
             // prefix block may make available another prefix block that is
             // LRU.
-            let (_, lru_prefix_hash, predecessor) = reclaimable_blocks.pop()?.0;
+            let (_, lru_prefix_hash) = self.leaves.pop_first()?;
            self.free_prefix_block(lru_prefix_hash);
-
-            // TODO: this is a leaky abstraction, avoid this.
-            if let Some(predecessor) = predecessor {
-                let state = &self.cache_blocks[&predecessor];
-                if state.ref_count == 0 {
-                    reclaimable_blocks.push(Reverse((
-                        state.last_accessed,
-                        predecessor,
-                        state.predecessor,
-                    )));
-                }
-            }
         }
 
         Some(
@@ -358,7 +337,7 @@ impl PrefixCache {
                 if let Some(predecessor) = predecessor {
                     self.incref_prefix(predecessor);
                 }
-                self.leaves.insert(prefix_hash);
+                self.leaves.insert((self.time, prefix_hash));
             }
         };
 
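
Not part of the diff itself, the snippet below is a minimal standalone sketch of the pattern the change adopts: keeping evictable leaves in a BTreeSet ordered by (last_accessed, prefix_hash), so the least recently used entry is always the first element and can be taken with pop_first, with no BinaryHeap rebuild per reclaim. The block ids and timestamps used here are made up for illustration.

use std::collections::BTreeSet;

fn main() {
    // Leaves ordered by (last_accessed, id): the smallest tuple is the stalest entry.
    let mut leaves: BTreeSet<(u64, u64)> = BTreeSet::new();

    // Insert a few leaves tagged with the logical time they were last accessed.
    leaves.insert((3, 42)); // id 42, last touched at t=3
    leaves.insert((1, 7));  // id 7, last touched at t=1 (stalest)
    leaves.insert((2, 13)); // id 13, last touched at t=2

    // Evict in LRU order: pop_first yields the smallest (time, id) pair.
    while let Some((time, id)) = leaves.pop_first() {
        println!("evicting {id} (last accessed at t={time})");
    }
    // Prints ids 7, 13, 42 in that order.
}

Freeing a leaf may turn its predecessor into a new leaf with ref_count == 0; because the predecessor is inserted into the same ordered set (as decref_prefix does in the diff), it automatically becomes a candidate for the next pop_first, which is what the removed BinaryHeap bookkeeping had to handle by hand.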