mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Proper support for two allocations with overlapping prefixes
This commit is contained in:
parent
d4ce5389ce
commit
dd2d6cfe40
@ -174,7 +174,13 @@ enum BlockAllocatorCommand {
|
||||
pub struct BlockAllocationWithCache {
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
pub prefix_hashes: Vec<u64>,
|
||||
pub allocation: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Allocation {
|
||||
cache_prefixes: Vec<u64>,
|
||||
new_prefixes: Vec<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -192,6 +198,8 @@ struct PrefixBlockState {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PrefixCacheAllocator {
|
||||
allocations: HashMap<u64, Allocation>,
|
||||
|
||||
/// Size of a paged attention block.
|
||||
block_size: usize,
|
||||
|
||||
@ -218,6 +226,7 @@ impl PrefixCacheAllocator {
|
||||
}
|
||||
|
||||
PrefixCacheAllocator {
|
||||
allocations: HashMap::new(),
|
||||
block_size,
|
||||
cache_blocks: HashMap::new(),
|
||||
free_blocks: (1..n_blocks as u32).collect(),
|
||||
@ -234,6 +243,7 @@ impl PrefixCacheAllocator {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
let mut prefix_cache_blocks = Vec::new();
|
||||
|
||||
// Find hashes for all block_sized prefill chunks.
|
||||
let mut prefix_hashes = Vec::new();
|
||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||
if prefill_chunk.len() < self.block_size {
|
||||
@ -259,14 +269,17 @@ impl PrefixCacheAllocator {
|
||||
n_from_cache += 1;
|
||||
}
|
||||
|
||||
let new_prefixes = prefix_hashes.split_off(n_from_cache);
|
||||
let cache_prefixes = prefix_hashes;
|
||||
|
||||
// Get tokens for the remaining prefill and decode.
|
||||
let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
|
||||
Some(blocks) => blocks,
|
||||
None => {
|
||||
// If the allocation fails, we have relinquish our use of the
|
||||
// prefix cache blocks. Maybe we can do this using `Drop`?
|
||||
for prefix_hash in &prefix_hashes[..n_from_cache] {
|
||||
self.decref_prefix(*prefix_hash);
|
||||
for prefix_hash in cache_prefixes {
|
||||
self.decref_prefix(prefix_hash);
|
||||
}
|
||||
|
||||
return None;
|
||||
@ -287,10 +300,19 @@ impl PrefixCacheAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
let allocation = Allocation {
|
||||
cache_prefixes,
|
||||
new_prefixes,
|
||||
};
|
||||
|
||||
let allocation_id = self.time;
|
||||
self.time += 1;
|
||||
self.allocations.insert(allocation_id, allocation);
|
||||
|
||||
Some(BlockAllocationWithCache {
|
||||
blocks: prefix_cache_blocks,
|
||||
slots,
|
||||
prefix_hashes,
|
||||
allocation: allocation_id,
|
||||
})
|
||||
}
|
||||
|
||||
@ -347,15 +369,40 @@ impl PrefixCacheAllocator {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn free(&mut self, blocks: &[u32], prefix_hashes: &[u64]) {
|
||||
pub fn free(&mut self, blocks: &[u32], allocation: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
|
||||
let mut predecessor = None;
|
||||
|
||||
for (&prefix_hash, &block_id) in prefix_hashes.iter().zip(blocks.iter()) {
|
||||
// Bookkeeping for prefill blocks that were retrieved from the cache.
|
||||
for &prefix_hash in &allocation.cache_prefixes {
|
||||
let state = match self.cache_blocks.get_mut(&prefix_hash) {
|
||||
Some(state) => state,
|
||||
None => unreachable!("Tried to free an unknown prefix block."),
|
||||
};
|
||||
|
||||
state.last_accessed = self.time;
|
||||
self.decref_prefix(prefix_hash);
|
||||
predecessor = Some(prefix_hash);
|
||||
}
|
||||
|
||||
// Bookkeeping for prefill blocks that were new.
|
||||
for (&block_id, &prefix_hash) in blocks[allocation.cache_prefixes.len()..]
|
||||
.iter()
|
||||
.zip(&allocation.new_prefixes)
|
||||
{
|
||||
// We can't simply cache the block. There may have been a concurrent
|
||||
// allocation for the same prefix.
|
||||
match self.cache_blocks.entry(prefix_hash) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let value = entry.get_mut();
|
||||
value.last_accessed = self.time;
|
||||
self.decref_prefix(prefix_hash);
|
||||
// A concurrent request added the prefix to the cache. We'll
|
||||
// only update the last accessed time.
|
||||
let state = entry.get_mut();
|
||||
// TODO: also update the time in the LRU set.
|
||||
state.last_accessed = self.time;
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(PrefixBlockState {
|
||||
@ -364,19 +411,17 @@ impl PrefixCacheAllocator {
|
||||
predecessor,
|
||||
ref_count: 0,
|
||||
});
|
||||
|
||||
if let Some(predecessor) = predecessor {
|
||||
self.incref_prefix(predecessor);
|
||||
}
|
||||
self.leaves.insert((self.time, prefix_hash));
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
predecessor = Some(prefix_hash);
|
||||
}
|
||||
|
||||
self.time += 1;
|
||||
|
||||
for block in &blocks[prefix_hashes.len()..] {
|
||||
for block in &blocks[allocation.cache_prefixes.len() + allocation.new_prefixes.len()..] {
|
||||
self.free_blocks.push(*block);
|
||||
}
|
||||
}
|
||||
@ -394,7 +439,7 @@ mod tests {
|
||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![1, 2]);
|
||||
assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
|
||||
cache.free(&allocation.blocks, &allocation.prefix_hashes);
|
||||
cache.free(&allocation.blocks, allocation.allocation);
|
||||
|
||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![1, 2]);
|
||||
@ -412,8 +457,8 @@ mod tests {
|
||||
assert_eq!(allocation2.blocks, vec![1]);
|
||||
assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
|
||||
|
||||
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
|
||||
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
|
||||
cache.free(&allocation1.blocks, allocation1.allocation);
|
||||
cache.free(&allocation2.blocks, allocation2.allocation);
|
||||
|
||||
// We should get the blocks of the first allocation, since they are more recent.
|
||||
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
|
||||
@ -427,7 +472,7 @@ mod tests {
|
||||
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
|
||||
let allocation2 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
|
||||
|
||||
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
|
||||
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
|
||||
cache.free(&allocation2.blocks, allocation2.allocation);
|
||||
cache.free(&allocation1.blocks, allocation1.allocation);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user