mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Proper support for two allocations with overlapping prefixes
This commit is contained in:
parent
d4ce5389ce
commit
dd2d6cfe40
@ -174,7 +174,13 @@ enum BlockAllocatorCommand {
|
|||||||
pub struct BlockAllocationWithCache {
|
pub struct BlockAllocationWithCache {
|
||||||
pub blocks: Vec<u32>,
|
pub blocks: Vec<u32>,
|
||||||
pub slots: Vec<u32>,
|
pub slots: Vec<u32>,
|
||||||
pub prefix_hashes: Vec<u64>,
|
pub allocation: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Allocation {
|
||||||
|
cache_prefixes: Vec<u64>,
|
||||||
|
new_prefixes: Vec<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -192,6 +198,8 @@ struct PrefixBlockState {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct PrefixCacheAllocator {
|
pub struct PrefixCacheAllocator {
|
||||||
|
allocations: HashMap<u64, Allocation>,
|
||||||
|
|
||||||
/// Size of a paged attention block.
|
/// Size of a paged attention block.
|
||||||
block_size: usize,
|
block_size: usize,
|
||||||
|
|
||||||
@ -218,6 +226,7 @@ impl PrefixCacheAllocator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PrefixCacheAllocator {
|
PrefixCacheAllocator {
|
||||||
|
allocations: HashMap::new(),
|
||||||
block_size,
|
block_size,
|
||||||
cache_blocks: HashMap::new(),
|
cache_blocks: HashMap::new(),
|
||||||
free_blocks: (1..n_blocks as u32).collect(),
|
free_blocks: (1..n_blocks as u32).collect(),
|
||||||
@ -234,6 +243,7 @@ impl PrefixCacheAllocator {
|
|||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
let mut prefix_cache_blocks = Vec::new();
|
let mut prefix_cache_blocks = Vec::new();
|
||||||
|
|
||||||
|
// Find hashes for all block_sized prefill chunks.
|
||||||
let mut prefix_hashes = Vec::new();
|
let mut prefix_hashes = Vec::new();
|
||||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||||
if prefill_chunk.len() < self.block_size {
|
if prefill_chunk.len() < self.block_size {
|
||||||
@ -259,14 +269,17 @@ impl PrefixCacheAllocator {
|
|||||||
n_from_cache += 1;
|
n_from_cache += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let new_prefixes = prefix_hashes.split_off(n_from_cache);
|
||||||
|
let cache_prefixes = prefix_hashes;
|
||||||
|
|
||||||
// Get tokens for the remaining prefill and decode.
|
// Get tokens for the remaining prefill and decode.
|
||||||
let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
|
let blocks = match self.alloc_or_reclaim(n_tokens - (n_from_cache * self.block_size)) {
|
||||||
Some(blocks) => blocks,
|
Some(blocks) => blocks,
|
||||||
None => {
|
None => {
|
||||||
// If the allocation fails, we have to relinquish our use of the
|
// If the allocation fails, we have to relinquish our use of the
|
||||||
// prefix cache blocks. Maybe we can do this using `Drop`?
|
// prefix cache blocks. Maybe we can do this using `Drop`?
|
||||||
for prefix_hash in &prefix_hashes[..n_from_cache] {
|
for prefix_hash in cache_prefixes {
|
||||||
self.decref_prefix(*prefix_hash);
|
self.decref_prefix(prefix_hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
return None;
|
return None;
|
||||||
@ -287,10 +300,19 @@ impl PrefixCacheAllocator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let allocation = Allocation {
|
||||||
|
cache_prefixes,
|
||||||
|
new_prefixes,
|
||||||
|
};
|
||||||
|
|
||||||
|
let allocation_id = self.time;
|
||||||
|
self.time += 1;
|
||||||
|
self.allocations.insert(allocation_id, allocation);
|
||||||
|
|
||||||
Some(BlockAllocationWithCache {
|
Some(BlockAllocationWithCache {
|
||||||
blocks: prefix_cache_blocks,
|
blocks: prefix_cache_blocks,
|
||||||
slots,
|
slots,
|
||||||
prefix_hashes,
|
allocation: allocation_id,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -347,15 +369,40 @@ impl PrefixCacheAllocator {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn free(&mut self, blocks: &[u32], prefix_hashes: &[u64]) {
|
pub fn free(&mut self, blocks: &[u32], allocation: u64) {
|
||||||
|
let allocation = match self.allocations.remove(&allocation) {
|
||||||
|
Some(allocation) => allocation,
|
||||||
|
None => unreachable!("Tried to free an unknown allocation."),
|
||||||
|
};
|
||||||
|
|
||||||
let mut predecessor = None;
|
let mut predecessor = None;
|
||||||
|
|
||||||
for (&prefix_hash, &block_id) in prefix_hashes.iter().zip(blocks.iter()) {
|
// Bookkeeping for prefill blocks that were retrieved from the cache.
|
||||||
|
for &prefix_hash in &allocation.cache_prefixes {
|
||||||
|
let state = match self.cache_blocks.get_mut(&prefix_hash) {
|
||||||
|
Some(state) => state,
|
||||||
|
None => unreachable!("Tried to free an unknown prefix block."),
|
||||||
|
};
|
||||||
|
|
||||||
|
state.last_accessed = self.time;
|
||||||
|
self.decref_prefix(prefix_hash);
|
||||||
|
predecessor = Some(prefix_hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bookkeeping for prefill blocks that were new.
|
||||||
|
for (&block_id, &prefix_hash) in blocks[allocation.cache_prefixes.len()..]
|
||||||
|
.iter()
|
||||||
|
.zip(&allocation.new_prefixes)
|
||||||
|
{
|
||||||
|
// We can't simply cache the block. There may have been a concurrent
|
||||||
|
// allocation for the same prefix.
|
||||||
match self.cache_blocks.entry(prefix_hash) {
|
match self.cache_blocks.entry(prefix_hash) {
|
||||||
Entry::Occupied(mut entry) => {
|
Entry::Occupied(mut entry) => {
|
||||||
let value = entry.get_mut();
|
// A concurrent request added the prefix to the cache. We'll
|
||||||
value.last_accessed = self.time;
|
// only update the last accessed time.
|
||||||
self.decref_prefix(prefix_hash);
|
let state = entry.get_mut();
|
||||||
|
// TODO: also update the time in the LRU set.
|
||||||
|
state.last_accessed = self.time;
|
||||||
}
|
}
|
||||||
Entry::Vacant(entry) => {
|
Entry::Vacant(entry) => {
|
||||||
entry.insert(PrefixBlockState {
|
entry.insert(PrefixBlockState {
|
||||||
@ -364,19 +411,17 @@ impl PrefixCacheAllocator {
|
|||||||
predecessor,
|
predecessor,
|
||||||
ref_count: 0,
|
ref_count: 0,
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(predecessor) = predecessor {
|
if let Some(predecessor) = predecessor {
|
||||||
self.incref_prefix(predecessor);
|
self.incref_prefix(predecessor);
|
||||||
}
|
}
|
||||||
self.leaves.insert((self.time, prefix_hash));
|
self.leaves.insert((self.time, prefix_hash));
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
predecessor = Some(prefix_hash);
|
predecessor = Some(prefix_hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.time += 1;
|
for block in &blocks[allocation.cache_prefixes.len() + allocation.new_prefixes.len()..] {
|
||||||
|
|
||||||
for block in &blocks[prefix_hashes.len()..] {
|
|
||||||
self.free_blocks.push(*block);
|
self.free_blocks.push(*block);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -394,7 +439,7 @@ mod tests {
|
|||||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
|
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![1, 2]);
|
assert_eq!(allocation.blocks, vec![1, 2]);
|
||||||
assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
|
assert_eq!(allocation.slots, (4..12).collect::<Vec<_>>());
|
||||||
cache.free(&allocation.blocks, &allocation.prefix_hashes);
|
cache.free(&allocation.blocks, allocation.allocation);
|
||||||
|
|
||||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
|
let allocation = cache.alloc(8, &[0, 1, 2, 3]).unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![1, 2]);
|
assert_eq!(allocation.blocks, vec![1, 2]);
|
||||||
@ -412,8 +457,8 @@ mod tests {
|
|||||||
assert_eq!(allocation2.blocks, vec![1]);
|
assert_eq!(allocation2.blocks, vec![1]);
|
||||||
assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
|
assert_eq!(allocation2.slots, (2..4).collect::<Vec<_>>());
|
||||||
|
|
||||||
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
|
cache.free(&allocation1.blocks, allocation1.allocation);
|
||||||
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
|
cache.free(&allocation2.blocks, allocation2.allocation);
|
||||||
|
|
||||||
// We should get the blocks of the first allocation, since they are more recent.
|
// We should get the blocks of the first allocation, since they are more recent.
|
||||||
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
|
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]).unwrap();
|
||||||
@ -427,7 +472,7 @@ mod tests {
|
|||||||
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
|
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
|
||||||
let allocation2 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
|
let allocation2 = cache.alloc(4, &[0, 1, 2, 3]).unwrap();
|
||||||
|
|
||||||
cache.free(&allocation2.blocks, &allocation2.prefix_hashes);
|
cache.free(&allocation2.blocks, allocation2.allocation);
|
||||||
cache.free(&allocation1.blocks, &allocation1.prefix_hashes);
|
cache.free(&allocation1.blocks, allocation1.allocation);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user