This commit is contained in:
Daniël de Kok 2024-07-16 10:10:10 +00:00
parent 7c046c9190
commit 0e6ff1293a

View File

@ -233,6 +233,7 @@ impl PrefixCacheAllocator {
) -> Option<BlockAllocationWithCache> {
let mut hasher = DefaultHasher::new();
let mut prefix_cache_blocks = Vec::new();
let mut prefix_hashes = Vec::new();
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
if prefill_chunk.len() < self.block_size {
@ -240,10 +241,12 @@ impl PrefixCacheAllocator {
}
prefill_chunk.hash(&mut hasher);
prefix_hashes.push(hasher.finish());
}
let prefix_hash = hasher.finish();
let block_id = match self.cache_blocks.get(&prefix_hash) {
let mut n_from_cache = 0;
for prefix_hash in prefix_hashes.iter() {
let block_id = match self.cache_blocks.get(prefix_hash) {
Some(state) => state.block_id,
None => break,
};
@ -251,20 +254,19 @@ impl PrefixCacheAllocator {
// We have to acquire the prefixes blocks, even if the allocation fails
// later, otherwise the allocation below could garbage collect the
// prefix blocks.
self.incref_prefix(prefix_hash);
prefix_hashes.push(prefix_hash);
self.incref_prefix(*prefix_hash);
prefix_cache_blocks.push(block_id);
n_from_cache += 1;
}
// Get tokens for the remaining prefill and decode.
let blocks = match self.alloc_or_reclaim(n_tokens - (prefix_hashes.len() * self.block_size))
{
let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
Some(blocks) => blocks,
None => {
// If the allocation fails, we have relinquish our use of the
// prefix cache blocks. Maybe we can do this using `Drop`?
for prefix_hash in prefix_hashes {
self.decref_prefix(prefix_hash);
for prefix_hash in &prefix_hashes[..n_from_cache] {
self.decref_prefix(*prefix_hash);
}
return None;
@ -389,63 +391,33 @@ mod tests {
#[test]
fn test_prefix_cache() {
let mut cache = PrefixCacheAllocator::new(4, 3, None);
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
assert_eq!(
allocation,
Some(BlockAllocationWithCache {
blocks: vec![1, 2],
slots: (4..12).collect(),
prefix_hashes: Vec::new(),
})
);
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
assert_eq!(allocation.blocks, vec![1, 2]);
assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
cache.free(&allocation.blocks, &allocation.prefix_hashes);
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
assert_eq!(
allocation,
Some(BlockAllocationWithCache {
blocks: vec![1, 2],
slots: (4..12).collect(),
prefix_hashes: Vec::new(),
})
);
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
assert_eq!(allocation.blocks, vec![1, 2]);
assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
}
#[test]
fn test_older_prefixes_are_collected_first() {
let mut cache = PrefixCacheAllocator::new(2, 4, None);
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]);
assert_eq!(
allocation1,
Some(BlockAllocationWithCache {
blocks: vec![2, 3],
slots: (4..8).collect(),
prefix_hashes: Vec::new(),
})
);
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
assert_eq!(allocation1.blocks, vec![2, 3]);
assert_eq!(allocation1.slots, (4..8).collect::<Vec<_>>());
let allocation2 = cache.alloc(2, &[4, 5]);
assert_eq!(
allocation2,
Some(BlockAllocationWithCache {
blocks: vec![1],
slots: (2..4).collect(),
prefix_hashes: Vec::new(),
})
);
let allocation2 = cache.alloc(2, &[4, 5]).unwrap();
assert_eq!(allocation2.blocks, vec![1]);
assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
cache.free(&allocation1.unwrap().blocks, &[0, 1, 2, 3]);
cache.free(&allocation2.unwrap().blocks, &[4, 5]);
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
// We should get the blocks of the first allocation, since they are more recent.
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]);
assert_eq!(
allocation3,
Some(BlockAllocationWithCache {
blocks: vec![3, 2],
slots: vec![6, 7, 4, 5],
prefix_hashes: Vec::new(),
})
);
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
assert_eq!(allocation3.blocks, vec![3, 2]);
assert_eq!(allocation3.slots, vec![6, 7, 4, 5]);
}
}