diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 022625dc..34275df0 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,8 +1,8 @@ use itertools::Itertools; use std::{ borrow::BorrowMut, - cmp::min, - collections::{hash_map::Entry, HashMap, HashSet}, + cmp::{min, Reverse}, + collections::{hash_map::Entry, BinaryHeap, HashMap, HashSet}, hash::{DefaultHasher, Hash, Hasher}, }; use tokio::sync::{mpsc, oneshot}; @@ -257,7 +257,7 @@ impl PrefixCache { }) } - fn free_prefix(&mut self, prefix_hash: u64) { + fn free_prefix_block(&mut self, prefix_hash: u64) { let state = self .cache_blocks .remove(&prefix_hash) @@ -269,6 +269,7 @@ impl PrefixCache { } self.leaves.remove(&prefix_hash); + self.free_blocks.push(state.block_id); } fn decref_prefix(&mut self, prefix_hash: u64) { @@ -293,36 +294,40 @@ impl PrefixCache { } fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option> { - let n_blocks = (n_tokens + self.block_size - 1) / self.block_size; - let n_blocks_needed = if n_blocks > self.free_blocks.len() { - n_blocks - self.free_blocks.len() - } else { - 0 - }; + let n_blocks_needed = (n_tokens + self.block_size - 1) / self.block_size; + + let mut reclaimable_blocks = self + .leaves + .iter() + .map(|prefix_hash| { + let state = &self.cache_blocks[prefix_hash]; + Reverse((state.last_accessed, *prefix_hash, state.predecessor)) + }) + .collect::>(); while self.free_blocks.len() < n_blocks_needed { - // We have to free one block at a time, because removing the LRU + // We have to free one block at a time because removing the LRU // prefix block may make available another prefix block that is // LRU. - // - // TODO: switch to something like a binary heap to avoid sorting - // the set of leaves over and over again. + let (_, lru_prefix_hash, predecessor) = reclaimable_blocks.pop()?.0; + self.free_prefix_block(lru_prefix_hash); - let (lru_prefix_hash, lru_block_id) = self - .leaves - .iter() - .map(|prefix_hash| (prefix_hash, &self.cache_blocks[prefix_hash])) - .sorted_by_key(|state| state.1.last_accessed) - .map(|(prefix_hash, state)| (*prefix_hash, state.block_id)) - .next()?; - - self.free_prefix(lru_prefix_hash); - self.free_blocks.push(lru_block_id); + // TODO: this is a leaky abstraction, avoid this. + if let Some(predecessor) = predecessor { + let state = &self.cache_blocks[&predecessor]; + if state.ref_count == 0 { + reclaimable_blocks.push(Reverse(( + state.last_accessed, + predecessor, + state.predecessor, + ))); + } + } } Some( self.free_blocks - .split_off(self.free_blocks.len() - n_blocks), + .split_off(self.free_blocks.len() - n_blocks_needed), ) }