Walk up to predecessors

This commit is contained in:
Daniël de Kok 2024-07-10 13:30:33 +02:00
parent 9da64a7b16
commit 3b6bef4078

View File

@ -2,7 +2,7 @@ use itertools::Itertools;
use std::{
borrow::BorrowMut,
cmp::min,
collections::{hash_map::Entry, HashMap},
collections::{hash_map::Entry, HashMap, HashSet},
hash::{DefaultHasher, Hash, Hasher},
};
use tokio::sync::{mpsc, oneshot};
@ -160,6 +160,7 @@ pub struct PrefixCache {
block_size: usize,
cache_partial: bool,
free_blocks: Vec<u32>,
leaves: HashSet<u64>,
cache_blocks: HashMap<u64, BlockState>,
// Avoid a system call, use a counter for time.
@ -173,18 +174,18 @@ impl PrefixCache {
cache_blocks: HashMap::new(),
cache_partial,
free_blocks: (1..n_blocks as u32).collect(),
leaves: HashSet::new(),
time: 0,
}
}
fn alloc(
&mut self,
mut n_tokens: usize,
n_tokens: usize,
prefill_tokens: &[u32],
) -> Option<BlockAllocationWithCache> {
let mut hasher = DefaultHasher::new();
let mut tokens_from_cache = 0;
let mut cache_blocks_hashes = Vec::new();
let mut prefix_cache_blocks = Vec::new();
let mut prefix_hashes = Vec::new();
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
@ -194,15 +195,14 @@ impl PrefixCache {
prefill_chunk.hash(&mut hasher);
match self.cache_blocks.get_mut(&hasher.finish()) {
Some(state) => {
// Ensure that we don't evict the prefix blocks.
state.ref_count += 1;
prefix_cache_blocks.push(state.block_id);
prefix_hashes.push(hasher.finish());
}
let block_id = match self.cache_blocks.get(&hasher.finish()) {
Some(state) => state.block_id,
None => break,
}
};
self.incref_prefix(hasher.finish());
prefix_hashes.push(hasher.finish());
prefix_cache_blocks.push(block_id);
tokens_from_cache += prefill_chunk.len();
}
@ -218,19 +218,6 @@ impl PrefixCache {
}
};
for hash in cache_blocks_hashes {
match self.cache_blocks.get_mut(&hash) {
Some(info) => {
info.ref_count += 1;
prefix_cache_blocks.push(info.block_id);
}
None => unreachable!(),
}
}
eprintln!("Allocated blocks: {:?}", blocks);
eprintln!("Prefix blocks: {:?}", prefix_cache_blocks);
prefix_cache_blocks.extend(blocks);
let mut slots = Vec::with_capacity(n_tokens);
@ -251,6 +238,40 @@ impl PrefixCache {
})
}
/// Drops the cache entry registered under `prefix_hash`.
///
/// The entry is taken out of `cache_blocks`. When it records a predecessor,
/// the reference this entry held on that predecessor is released through
/// `decref_prefix`. Finally the hash is removed from `leaves`.
///
/// Note that this does not touch `free_blocks`; handing the block id back
/// is left to the caller.
///
/// # Panics
///
/// Panics with "Unknown hash" if `cache_blocks` has no entry for
/// `prefix_hash` (or, via `decref_prefix`, none for its predecessor).
fn free_prefix(&mut self, prefix_hash: u64) {
    let evicted = self.cache_blocks.remove(&prefix_hash).expect("Unknown hash");

    // The evicted entry no longer holds a reference on the entry it
    // extended.
    if let Some(parent_hash) = evicted.predecessor {
        self.decref_prefix(parent_hash);
    }

    // An entry that is gone must not remain in the leaf set.
    self.leaves.remove(&prefix_hash);
}
/// Releases one reference on the cache entry registered under `prefix_hash`.
///
/// Once the reference count drops to zero the hash is added to `leaves`.
///
/// # Panics
///
/// Panics with "Unknown hash" if `cache_blocks` has no entry for
/// `prefix_hash`, and panics if the entry's reference count is already zero.
fn decref_prefix(&mut self, prefix_hash: u64) {
    let state = self.cache_blocks.get_mut(&prefix_hash).expect("Unknown hash");

    // Releasing a reference that was never taken is a bookkeeping bug.
    assert!(state.ref_count > 0);
    state.ref_count -= 1;

    // Capture the result first so the borrow of `cache_blocks` ends before
    // `leaves` is modified.
    let unreferenced = state.ref_count == 0;
    if unreferenced {
        self.leaves.insert(prefix_hash);
    }
}
/// Takes one additional reference on the cache entry registered under
/// `prefix_hash`.
///
/// The hash is also removed from `leaves`, since the entry now has at least
/// one reference.
///
/// # Panics
///
/// Panics with "Unknown hash" if `cache_blocks` has no entry for
/// `prefix_hash`.
fn incref_prefix(&mut self, prefix_hash: u64) {
    self.cache_blocks
        .get_mut(&prefix_hash)
        .expect("Unknown hash")
        .ref_count += 1;

    // A referenced entry must not stay in the leaf set.
    self.leaves.remove(&prefix_hash);
}
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
let n_blocks = (n_tokens + self.block_size - 1) / self.block_size;
let n_blocks_needed = if n_tokens > self.free_blocks.len() {
@ -259,28 +280,24 @@ impl PrefixCache {
0
};
if n_blocks_needed > 0 {
let removable_blocks = self
.cache_blocks
.iter_mut()
// Block must be unused.
.filter(|(_, state)| state.ref_count == 0)
// Remove most recent block first.
// TODO: we are not yet removing a prefix in reverse order.
.sorted_by_key(|(_, state)| state.last_accessed)
// Find enough candidates.
.take(n_blocks_needed)
.map(|(block_hash, block_state)| (*block_hash, block_state.block_id))
.collect::<Vec<_>>();
while self.free_blocks.len() < n_blocks_needed {
// We have to free one block at a time, because removing the LRU
// prefix block may make available another prefix block that is
// LRU.
//
// TODO: switch to something like a binary heap to avoid sorting
// the set of leaves over and over again.
if removable_blocks.len() < n_blocks_needed {
return None;
}
let (lru_prefix_hash, lru_block_id) = self
.leaves
.iter()
.map(|prefix_hash| (prefix_hash, &self.cache_blocks[prefix_hash]))
.sorted_by_key(|state| state.1.last_accessed)
.map(|(prefix_hash, state)| (*prefix_hash, state.block_id))
.next()?;
for (block_hash, block_id) in removable_blocks.into_iter() {
self.free_blocks.push(block_id);
self.cache_blocks.remove(&block_hash);
}
self.free_prefix(lru_prefix_hash);
self.free_blocks.push(lru_block_id);
}
Some(