First step towards cleaning up

Breaks tests, but I want to shuffle around data structures so that
we can just pass block ids to free.
This commit is contained in:
Daniël de Kok 2024-07-15 14:14:10 +00:00
parent 05611f6b40
commit 7c046c9190

View File

@ -174,6 +174,7 @@ enum BlockAllocatorCommand {
pub struct BlockAllocationWithCache { pub struct BlockAllocationWithCache {
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
pub prefix_hashes: Vec<u64>,
} }
#[derive(Debug)] #[derive(Debug)]
@ -225,7 +226,7 @@ impl PrefixCacheAllocator {
} }
} }
fn alloc( pub fn alloc(
&mut self, &mut self,
n_tokens: usize, n_tokens: usize,
prefill_tokens: &[u32], prefill_tokens: &[u32],
@ -287,6 +288,7 @@ impl PrefixCacheAllocator {
Some(BlockAllocationWithCache { Some(BlockAllocationWithCache {
blocks: prefix_cache_blocks, blocks: prefix_cache_blocks,
slots, slots,
prefix_hashes,
}) })
} }
@ -343,17 +345,10 @@ impl PrefixCacheAllocator {
) )
} }
fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) { pub fn free(&mut self, blocks: &[u32], prefix_hashes: &[u64]) {
let mut hasher = DefaultHasher::new();
let mut predecessor = None; let mut predecessor = None;
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
if prefill_chunk.len() < self.block_size {
break;
}
prefill_chunk.hash(&mut hasher);
let prefix_hash = hasher.finish();
for (&prefix_hash, &block_id) in prefix_hashes.iter().zip(blocks.iter()) {
match self.cache_blocks.entry(prefix_hash) { match self.cache_blocks.entry(prefix_hash) {
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
let value = entry.get_mut(); let value = entry.get_mut();
@ -362,7 +357,7 @@ impl PrefixCacheAllocator {
} }
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
entry.insert(PrefixBlockState { entry.insert(PrefixBlockState {
block_id: *block_id, block_id,
last_accessed: self.time, last_accessed: self.time,
predecessor, predecessor,
ref_count: 0, ref_count: 0,
@ -379,8 +374,7 @@ impl PrefixCacheAllocator {
self.time += 1; self.time += 1;
let n_prefill_blocks = (prefill_tokens.len() + self.block_size - 1) / self.block_size; for block in &blocks[prefix_hashes.len()..] {
for block in &blocks[n_prefill_blocks..] {
self.free_blocks.push(*block); self.free_blocks.push(*block);
} }
} }
@ -400,7 +394,8 @@ mod tests {
allocation, allocation,
Some(BlockAllocationWithCache { Some(BlockAllocationWithCache {
blocks: vec![1, 2], blocks: vec![1, 2],
slots: (4..12).collect() slots: (4..12).collect(),
prefix_hashes: Vec::new(),
}) })
); );
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]); cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
@ -410,7 +405,8 @@ mod tests {
allocation, allocation,
Some(BlockAllocationWithCache { Some(BlockAllocationWithCache {
blocks: vec![1, 2], blocks: vec![1, 2],
slots: (4..12).collect() slots: (4..12).collect(),
prefix_hashes: Vec::new(),
}) })
); );
} }
@ -423,7 +419,8 @@ mod tests {
allocation1, allocation1,
Some(BlockAllocationWithCache { Some(BlockAllocationWithCache {
blocks: vec![2, 3], blocks: vec![2, 3],
slots: (4..8).collect() slots: (4..8).collect(),
prefix_hashes: Vec::new(),
}) })
); );
@ -432,7 +429,8 @@ mod tests {
allocation2, allocation2,
Some(BlockAllocationWithCache { Some(BlockAllocationWithCache {
blocks: vec![1], blocks: vec![1],
slots: (2..4).collect() slots: (2..4).collect(),
prefix_hashes: Vec::new(),
}) })
); );
@ -445,7 +443,8 @@ mod tests {
allocation3, allocation3,
Some(BlockAllocationWithCache { Some(BlockAllocationWithCache {
blocks: vec![3, 2], blocks: vec![3, 2],
slots: vec![6, 7, 4, 5] slots: vec![6, 7, 4, 5],
prefix_hashes: Vec::new(),
}) })
); );
} }