This commit is contained in:
Daniël de Kok 2024-07-16 10:10:10 +00:00
parent 7c046c9190
commit 0e6ff1293a

View File

@ -233,6 +233,7 @@ impl PrefixCacheAllocator {
) -> Option<BlockAllocationWithCache> { ) -> Option<BlockAllocationWithCache> {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
let mut prefix_cache_blocks = Vec::new(); let mut prefix_cache_blocks = Vec::new();
let mut prefix_hashes = Vec::new(); let mut prefix_hashes = Vec::new();
for prefill_chunk in prefill_tokens.chunks(self.block_size) { for prefill_chunk in prefill_tokens.chunks(self.block_size) {
if prefill_chunk.len() < self.block_size { if prefill_chunk.len() < self.block_size {
@ -240,10 +241,12 @@ impl PrefixCacheAllocator {
} }
prefill_chunk.hash(&mut hasher); prefill_chunk.hash(&mut hasher);
prefix_hashes.push(hasher.finish());
}
let prefix_hash = hasher.finish(); let mut n_from_cache = 0;
for prefix_hash in prefix_hashes.iter() {
let block_id = match self.cache_blocks.get(&prefix_hash) { let block_id = match self.cache_blocks.get(prefix_hash) {
Some(state) => state.block_id, Some(state) => state.block_id,
None => break, None => break,
}; };
@ -251,20 +254,19 @@ impl PrefixCacheAllocator {
// We have to acquire the prefixes blocks, even if the allocation fails // We have to acquire the prefixes blocks, even if the allocation fails
// later, otherwise the allocation below could garbage collect the // later, otherwise the allocation below could garbage collect the
// prefix blocks. // prefix blocks.
self.incref_prefix(prefix_hash); self.incref_prefix(*prefix_hash);
prefix_hashes.push(prefix_hash);
prefix_cache_blocks.push(block_id); prefix_cache_blocks.push(block_id);
n_from_cache += 1;
} }
// Get tokens for the remaining prefill and decode. // Get tokens for the remaining prefill and decode.
let blocks = match self.alloc_or_reclaim(n_tokens - (prefix_hashes.len() * self.block_size)) let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
{
Some(blocks) => blocks, Some(blocks) => blocks,
None => { None => {
// If the allocation fails, we have relinquish our use of the // If the allocation fails, we have relinquish our use of the
// prefix cache blocks. Maybe we can do this using `Drop`? // prefix cache blocks. Maybe we can do this using `Drop`?
for prefix_hash in prefix_hashes { for prefix_hash in &prefix_hashes[..n_from_cache] {
self.decref_prefix(prefix_hash); self.decref_prefix(*prefix_hash);
} }
return None; return None;
@ -389,63 +391,33 @@ mod tests {
#[test] #[test]
fn test_prefix_cache() { fn test_prefix_cache() {
let mut cache = PrefixCacheAllocator::new(4, 3, None); let mut cache = PrefixCacheAllocator::new(4, 3, None);
let allocation = cache.alloc(8, &[0, 1, 2, 3]); let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
assert_eq!( assert_eq!(allocation.blocks, vec![1, 2]);
allocation, assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
Some(BlockAllocationWithCache { cache.free(&allocation.blocks, &allocation.prefix_hashes);
blocks: vec![1, 2],
slots: (4..12).collect(),
prefix_hashes: Vec::new(),
})
);
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
let allocation = cache.alloc(8, &[0, 1, 2, 3]); let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
assert_eq!( assert_eq!(allocation.blocks, vec![1, 2]);
allocation, assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
Some(BlockAllocationWithCache {
blocks: vec![1, 2],
slots: (4..12).collect(),
prefix_hashes: Vec::new(),
})
);
} }
#[test] #[test]
fn test_older_prefixes_are_collected_first() { fn test_older_prefixes_are_collected_first() {
let mut cache = PrefixCacheAllocator::new(2, 4, None); let mut cache = PrefixCacheAllocator::new(2, 4, None);
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]); let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
assert_eq!( assert_eq!(allocation1.blocks, vec![2, 3]);
allocation1, assert_eq!(allocation1.slots, (4..8).collect::<Vec<_>>());
Some(BlockAllocationWithCache {
blocks: vec![2, 3],
slots: (4..8).collect(),
prefix_hashes: Vec::new(),
})
);
let allocation2 = cache.alloc(2, &[4, 5]); let allocation2 = cache.alloc(2, &[4, 5]).unwrap();
assert_eq!( assert_eq!(allocation2.blocks, vec![1]);
allocation2, assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
Some(BlockAllocationWithCache {
blocks: vec![1],
slots: (2..4).collect(),
prefix_hashes: Vec::new(),
})
);
cache.free(&allocation1.unwrap().blocks, &[0, 1, 2, 3]); cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
cache.free(&allocation2.unwrap().blocks, &[4, 5]); cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
// We should get the blocks of the first allocation, since they are more recent. // We should get the blocks of the first allocation, since they are more recent.
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]); let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
assert_eq!( assert_eq!(allocation3.blocks, vec![3, 2]);
allocation3, assert_eq!(allocation3.slots, vec![6, 7, 4, 5]);
Some(BlockAllocationWithCache {
blocks: vec![3, 2],
slots: vec![6, 7, 4, 5],
prefix_hashes: Vec::new(),
})
);
} }
} }