Basic test passes

This commit is contained in:
Daniël de Kok 2024-07-09 16:37:47 +02:00
parent dbb82e274c
commit 9da64a7b16

View File

@ -151,9 +151,11 @@ pub(crate) struct BlockAllocationWithCache {
struct BlockState {
block_id: u32,
last_accessed: u64,
predecessor: Option<u64>,
ref_count: usize,
}
#[derive(Debug)]
pub struct PrefixCache {
block_size: usize,
cache_partial: bool,
@ -180,10 +182,11 @@ impl PrefixCache {
mut n_tokens: usize,
prefill_tokens: &[u32],
) -> Option<BlockAllocationWithCache> {
// First try to lookup prefix.
let mut hasher = DefaultHasher::new();
let mut tokens_from_cache = 0;
let mut cache_blocks_hashes = Vec::new();
let mut prefix_cache_blocks = Vec::new();
let mut prefix_hashes = Vec::new();
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
if prefill_chunk.len() < self.block_size && !self.cache_partial {
break;
@ -193,22 +196,28 @@ impl PrefixCache {
match self.cache_blocks.get_mut(&hasher.finish()) {
Some(state) => {
// Ensure that we don't evict the prefix blocks.
state.ref_count += 1;
prefix_cache_blocks.push(state.block_id);
prefix_hashes.push(hasher.finish());
}
None => todo!(),
}
if !self.cache_blocks.contains_key(&hasher.finish()) {
break;
None => break,
}
tokens_from_cache += prefill_chunk.len();
cache_blocks_hashes.push(hasher.finish());
}
let blocks = self.alloc_or_reclaim(n_tokens - tokens_from_cache)?;
let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
Some(blocks) => blocks,
None => {
prefix_hashes.into_iter().for_each(|hash| {
self.cache_blocks.get_mut(&hash).unwrap().ref_count -= 1;
});
return None;
}
};
let mut prefix_cache_blocks = Vec::new();
for hash in cache_blocks_hashes {
match self.cache_blocks.get_mut(&hash) {
Some(info) => {
@ -219,6 +228,9 @@ impl PrefixCache {
}
}
eprintln!("Allocated blocks: {:?}", blocks);
eprintln!("Prefix blocks: {:?}", prefix_cache_blocks);
prefix_cache_blocks.extend(blocks);
let mut slots = Vec::with_capacity(n_tokens);
@ -279,7 +291,7 @@ impl PrefixCache {
fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) {
let mut hasher = DefaultHasher::new();
let mut predecessor = None;
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
if prefill_chunk.len() < self.block_size && !self.cache_partial {
break;
@ -298,13 +310,21 @@ impl PrefixCache {
entry.insert(BlockState {
block_id: *block_id,
last_accessed: self.time,
predecessor,
ref_count: 0,
});
}
};
predecessor = Some(hasher.finish());
}
self.time += 1;
let n_prefill_blocks = (prefill_tokens.len() + self.block_size - 1) / self.block_size;
for block in &blocks[n_prefill_blocks..] {
self.free_blocks.push(*block);
}
}
}
@ -327,6 +347,8 @@ mod tests {
);
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
eprintln!("{:?}", cache);
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
assert_eq!(
allocation,