diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index b0f458c6..01b213d8 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -230,6 +230,8 @@ impl RadixAllocator { allocation_id: 0, allocations: HashMap::new(), cache_blocks: RadixTrie::new(), + + // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), } } @@ -248,7 +250,10 @@ impl RadixAllocator { } if self.free_blocks.len() >= n_blocks_needed { - Some(self.free_blocks.split_off(n_blocks_needed)) + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) } else { None } @@ -316,9 +321,21 @@ impl Allocator for RadixAllocator { // If there are prefill tokens that did not come from the cache, // add them to the cache. if prefill_tokens.len() > allocation.cached_prefix_len { - // TODO: check if the prefill tokens are already in the cache??? - self.cache_blocks + let prefix_len = self + .cache_blocks .insert(prefill_tokens, &blocks[..prefill_tokens.len()]); + + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + self.free_blocks + .extend(&blocks[allocation.cached_prefix_len..prefix_len]); } // Free non-prefill blocks. @@ -334,3 +351,60 @@ struct RadixAllocation { cached_prefix_len: usize, prefill_tokens: Option>>, } + +#[cfg(test)] +mod tests { + use std::{rc::Rc, sync::Arc}; + + use super::{Allocator, RadixAllocator}; + + #[test] + fn test_prefix_cache() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.0, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.1, allocation.0); + assert_eq!(allocation.2, 0); + cache.free(allocation.0, allocation.3); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.0, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.2, 4); + } + + #[test] + fn test_older_prefixes_are_collected_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.0, vec![3, 4, 5, 6]); + assert_eq!(allocation1.2, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.0, vec![1, 2]); + assert_eq!(allocation2.2, 0); + + cache.free(allocation1.0, allocation1.3); + cache.free(allocation2.0, allocation2.3); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.0, vec![3, 4, 5, 6]); + assert_eq!(allocation3.2, 0); + } + + #[test] + fn correctly_free_when_fully_overlapping_prefills_in_flight() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.0, allocation2.3); + cache.free(allocation1.0, allocation1.3); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.2, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } +} diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 715ed320..230a2fdf 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -118,6 +118,11 @@ impl RadixTrie { node.ref_count += 1; } + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize { self.time += 1; self.insert_(self.root, key, blocks) @@ -152,10 +157,10 @@ impl RadixTrie { let child_id = self.split_node(child_id, shared_prefix_len); let key = &key[shared_prefix_len..]; let blocks = &blocks[shared_prefix_len..]; - self.insert_(child_id, key, blocks) + shared_prefix_len + self.insert_(child_id, key, blocks) } else { self.add_node(node_id, key, blocks); - key.len() + 0 } }