docs/cleanups

This commit is contained in:
Daniël de Kok 2024-07-10 14:06:23 +02:00
parent 3b6bef4078
commit 6d0094e5d4

View File

@ -148,20 +148,37 @@ pub(crate) struct BlockAllocationWithCache {
} }
#[derive(Debug)] #[derive(Debug)]
struct BlockState { struct PrefixBlockState {
block_id: u32, block_id: u32,
/// Time (counter value) at which this prefix block was last used.
last_accessed: u64, last_accessed: u64,
/// Prefix predecessor (parent in the prefix trie).
predecessor: Option<u64>, predecessor: Option<u64>,
ref_count: usize, ref_count: usize,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct PrefixCache { pub struct PrefixCache {
/// Size of a paged attention block.
block_size: usize, block_size: usize,
/// Blocks that cache a prefix with the given hash.
///
/// The blocks form a Merkle tree, because a prefix block is dependent
/// on its preceding prefix block.
cache_blocks: HashMap<u64, PrefixBlockState>,
/// Whether to cache partial blocks.
cache_partial: bool, cache_partial: bool,
/// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>, free_blocks: Vec<u32>,
/// Prefix blocks with a reference count of zero.
leaves: HashSet<u64>, leaves: HashSet<u64>,
cache_blocks: HashMap<u64, BlockState>,
// Avoid a system call, use a counter for time. // Avoid a system call, use a counter for time.
time: u64, time: u64,
@ -195,24 +212,32 @@ impl PrefixCache {
prefill_chunk.hash(&mut hasher); prefill_chunk.hash(&mut hasher);
let block_id = match self.cache_blocks.get(&hasher.finish()) { let prefix_hash = hasher.finish();
let block_id = match self.cache_blocks.get(&prefix_hash) {
Some(state) => state.block_id, Some(state) => state.block_id,
None => break, None => break,
}; };
self.incref_prefix(hasher.finish()); // We have to acquire the prefix blocks, even if the allocation fails
prefix_hashes.push(hasher.finish()); // later, otherwise the allocation below could garbage collect the
// prefix blocks.
self.incref_prefix(prefix_hash);
prefix_hashes.push(prefix_hash);
prefix_cache_blocks.push(block_id); prefix_cache_blocks.push(block_id);
tokens_from_cache += prefill_chunk.len(); tokens_from_cache += prefill_chunk.len();
} }
// Allocate blocks for the remaining prefill and decode tokens.
let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) { let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
Some(blocks) => blocks, Some(blocks) => blocks,
None => { None => {
prefix_hashes.into_iter().for_each(|hash| { // If the allocation fails, we have to relinquish our use of the
self.cache_blocks.get_mut(&hash).unwrap().ref_count -= 1; // prefix cache blocks. Maybe we can do this using `Drop`?
}); for prefix_hash in prefix_hashes {
self.decref_prefix(prefix_hash);
}
return None; return None;
} }
@ -222,6 +247,7 @@ impl PrefixCache {
let mut slots = Vec::with_capacity(n_tokens); let mut slots = Vec::with_capacity(n_tokens);
for block_id in prefix_cache_blocks.iter() { for block_id in prefix_cache_blocks.iter() {
// TODO: fixme: doesn't work with cache_partial yet.
for s in for s in
(*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32) (*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
{ {
@ -244,6 +270,7 @@ impl PrefixCache {
.remove(&prefix_hash) .remove(&prefix_hash)
.expect("Unknown hash"); .expect("Unknown hash");
// Parent has one user less.
if let Some(predecessor) = state.predecessor { if let Some(predecessor) = state.predecessor {
self.decref_prefix(predecessor); self.decref_prefix(predecessor);
} }
@ -315,25 +342,26 @@ impl PrefixCache {
} }
prefill_chunk.hash(&mut hasher); prefill_chunk.hash(&mut hasher);
let prefix_hash = hasher.finish();
match self.cache_blocks.entry(hasher.finish()) { match self.cache_blocks.entry(prefix_hash) {
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
let value = entry.get_mut(); let value = entry.get_mut();
value.last_accessed = self.time; value.last_accessed = self.time;
assert!(value.ref_count > 0); self.decref_prefix(prefix_hash);
value.ref_count -= 1;
} }
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
entry.insert(BlockState { entry.insert(PrefixBlockState {
block_id: *block_id, block_id: *block_id,
last_accessed: self.time, last_accessed: self.time,
predecessor, predecessor,
ref_count: 0, ref_count: 0,
}); });
self.leaves.insert(prefix_hash);
} }
}; };
predecessor = Some(hasher.finish()); predecessor = Some(prefix_hash);
} }
self.time += 1; self.time += 1;