Better leaf tracking

This commit is contained in:
Daniël de Kok 2024-07-12 16:03:21 +02:00
parent 1a461234d5
commit 3b4754cd31

View File

@ -2,7 +2,7 @@ use itertools::Itertools;
use std::{ use std::{
borrow::BorrowMut, borrow::BorrowMut,
cmp::{min, Reverse}, cmp::{min, Reverse},
collections::{hash_map::Entry, BinaryHeap, HashMap, HashSet}, collections::{hash_map::Entry, BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet},
hash::{DefaultHasher, Hash, Hasher}, hash::{DefaultHasher, Hash, Hasher},
}; };
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
@ -142,7 +142,7 @@ enum BlockAllocatorCommand {
} }
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct BlockAllocationWithCache { pub struct BlockAllocationWithCache {
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
} }
@ -174,8 +174,8 @@ pub struct PrefixCache {
/// Blocks that are immediately available for allocation. /// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>, free_blocks: Vec<u32>,
/// Prefix blocks with a reference count of zero. /// Prefix blocks with a reference count of zero, by staleness.
leaves: HashSet<u64>, leaves: BTreeSet<(u64, u64)>,
// Avoid a system call, use a counter for time. // Avoid a system call, use a counter for time.
time: u64, time: u64,
@ -187,7 +187,7 @@ impl PrefixCache {
block_size, block_size,
cache_blocks: HashMap::new(), cache_blocks: HashMap::new(),
free_blocks: (1..n_blocks as u32).collect(), free_blocks: (1..n_blocks as u32).collect(),
leaves: HashSet::new(), leaves: BTreeSet::new(),
time: 0, time: 0,
} }
} }
@ -268,7 +268,7 @@ impl PrefixCache {
self.decref_prefix(predecessor); self.decref_prefix(predecessor);
} }
self.leaves.remove(&prefix_hash); self.leaves.remove(&(state.last_accessed, prefix_hash));
self.free_blocks.push(state.block_id); self.free_blocks.push(state.block_id);
} }
@ -280,7 +280,7 @@ impl PrefixCache {
assert!(state.ref_count > 0); assert!(state.ref_count > 0);
state.ref_count -= 1; state.ref_count -= 1;
if state.ref_count == 0 { if state.ref_count == 0 {
self.leaves.insert(prefix_hash); self.leaves.insert((state.last_accessed, prefix_hash));
} }
} }
@ -290,39 +290,18 @@ impl PrefixCache {
.get_mut(&prefix_hash) .get_mut(&prefix_hash)
.expect("Unknown hash"); .expect("Unknown hash");
state.ref_count += 1; state.ref_count += 1;
self.leaves.remove(&prefix_hash); self.leaves.remove(&(state.last_accessed, prefix_hash));
} }
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> { fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
let n_blocks_needed = (n_tokens + self.block_size - 1) / self.block_size; let n_blocks_needed = (n_tokens + self.block_size - 1) / self.block_size;
let mut reclaimable_blocks = self
.leaves
.iter()
.map(|prefix_hash| {
let state = &self.cache_blocks[prefix_hash];
Reverse((state.last_accessed, *prefix_hash, state.predecessor))
})
.collect::<BinaryHeap<_>>();
while self.free_blocks.len() < n_blocks_needed { while self.free_blocks.len() < n_blocks_needed {
// We have to free one block at a time because removing the LRU // We have to free one block at a time because removing the LRU
// prefix block may make available another prefix block that is // prefix block may make available another prefix block that is
// LRU. // LRU.
let (_, lru_prefix_hash, predecessor) = reclaimable_blocks.pop()?.0; let (_, lru_prefix_hash) = self.leaves.pop_first()?;
self.free_prefix_block(lru_prefix_hash); self.free_prefix_block(lru_prefix_hash);
// TODO: this is a leaky abstraction, avoid this.
if let Some(predecessor) = predecessor {
let state = &self.cache_blocks[&predecessor];
if state.ref_count == 0 {
reclaimable_blocks.push(Reverse((
state.last_accessed,
predecessor,
state.predecessor,
)));
}
}
} }
Some( Some(
@ -358,7 +337,7 @@ impl PrefixCache {
if let Some(predecessor) = predecessor { if let Some(predecessor) = predecessor {
self.incref_prefix(predecessor); self.incref_prefix(predecessor);
} }
self.leaves.insert(prefix_hash); self.leaves.insert((self.time, prefix_hash));
} }
}; };