Proper support for two allocations with overlapping prefixes

This commit is contained in:
Daniël de Kok 2024-07-16 11:40:35 +00:00
parent d4ce5389ce
commit dd2d6cfe40

View File

@ -174,7 +174,13 @@ enum BlockAllocatorCommand {
pub struct BlockAllocationWithCache {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
pub prefix_hashes: Vec<u64>,
pub allocation: u64,
}
#[derive(Debug)]
struct Allocation {
cache_prefixes: Vec<u64>,
new_prefixes: Vec<u64>,
}
#[derive(Debug)]
@ -192,6 +198,8 @@ struct PrefixBlockState {
#[derive(Debug)]
pub struct PrefixCacheAllocator {
allocations: HashMap<u64, Allocation>,
/// Size of a paged attention block.
block_size: usize,
@ -218,6 +226,7 @@ impl PrefixCacheAllocator {
}
PrefixCacheAllocator {
allocations: HashMap::new(),
block_size,
cache_blocks: HashMap::new(),
free_blocks: (1..n_blocks as u32).collect(),
@ -234,6 +243,7 @@ impl PrefixCacheAllocator {
let mut hasher = DefaultHasher::new();
let mut prefix_cache_blocks = Vec::new();
// Find hashes for all block_sized prefill chunks.
let mut prefix_hashes = Vec::new();
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
if prefill_chunk.len() < self.block_size {
@ -259,14 +269,17 @@ impl PrefixCacheAllocator {
n_from_cache += 1;
}
let new_prefixes = prefix_hashes.split_off(n_from_cache);
let cache_prefixes = prefix_hashes;
// Get tokens for the remaining prefill and decode.
let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
Some(blocks) => blocks,
None => {
// If the allocation fails, we have relinquish our use of the
// prefix cache blocks. Maybe we can do this using `Drop`?
for prefix_hash in &prefix_hashes[..n_from_cache] {
self.decref_prefix(*prefix_hash);
for prefix_hash in cache_prefixes {
self.decref_prefix(prefix_hash);
}
return None;
@ -287,10 +300,19 @@ impl PrefixCacheAllocator {
}
}
let allocation = Allocation {
cache_prefixes,
new_prefixes,
};
let allocation_id = self.time;
self.time += 1;
self.allocations.insert(allocation_id, allocation);
Some(BlockAllocationWithCache {
blocks: prefix_cache_blocks,
slots,
prefix_hashes,
allocation: allocation_id,
})
}
@ -347,15 +369,40 @@ impl PrefixCacheAllocator {
)
}
pub fn free(&mut self, blocks: &[u32], prefix_hashes: &[u64]) {
pub fn free(&mut self, blocks: &[u32], allocation: u64) {
let allocation = match self.allocations.remove(&allocation) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
let mut predecessor = None;
for (&prefix_hash, &block_id) in prefix_hashes.iter().zip(blocks.iter()) {
// Bookkeeping for prefill blocks that were retrieved from the cache.
for &prefix_hash in &allocation.cache_prefixes {
let state = match self.cache_blocks.get_mut(&prefix_hash) {
Some(state) => state,
None => unreachable!("Tried to free an unknown prefix block."),
};
state.last_accessed = self.time;
self.decref_prefix(prefix_hash);
predecessor = Some(prefix_hash);
}
// Bookkeeping for prefill blocks that were new.
for (&block_id, &prefix_hash) in blocks[allocation.cache_prefixes.len()..]
.iter()
.zip(&allocation.new_prefixes)
{
// We can't simply cache the block. There may have been a concurrent
// allocation for the same prefix.
match self.cache_blocks.entry(prefix_hash) {
Entry::Occupied(mut entry) => {
let value = entry.get_mut();
value.last_accessed = self.time;
self.decref_prefix(prefix_hash);
// A concurrent request added the prefix to the cache. We'll
// only update the last accessed time.
let state = entry.get_mut();
// TODO: also update the time in the LRU set.
state.last_accessed = self.time;
}
Entry::Vacant(entry) => {
entry.insert(PrefixBlockState {
@ -364,19 +411,17 @@ impl PrefixCacheAllocator {
predecessor,
ref_count: 0,
});
if let Some(predecessor) = predecessor {
self.incref_prefix(predecessor);
}
self.leaves.insert((self.time, prefix_hash));
}
};
}
predecessor = Some(prefix_hash);
}
self.time += 1;
for block in &blocks[prefix_hashes.len()..] {
for block in &blocks[allocation.cache_prefixes.len() + allocation.new_prefixes.len()..] {
self.free_blocks.push(*block);
}
}
@ -394,7 +439,7 @@ mod tests {
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
assert_eq!(allocation.blocks, vec![1, 2]);
assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
cache.free(&allocation.blocks, &allocation.prefix_hashes);
cache.free(&allocation.blocks, allocation.allocation);
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
assert_eq!(allocation.blocks, vec![1, 2]);
@ -412,8 +457,8 @@ mod tests {
assert_eq!(allocation2.blocks, vec![1]);
assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
cache.free(&allocation1.blocks, allocation1.allocation);
cache.free(&allocation2.blocks, allocation2.allocation);
// We should get the blocks of the first allocation, since they are more recent.
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
@ -427,7 +472,7 @@ mod tests {
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
let allocation2 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
cache.free(&allocation2.blocks, allocation2.allocation);
cache.free(&allocation1.blocks, allocation1.allocation);
}
}