From 785c6e4893508d3f21eeff647d032fa78d41e182 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 6 Sep 2024 17:31:32 +0200 Subject: [PATCH] Fixed the radix tree. Used a slice everywhere in radix.rs to keep the cheap Arc cloning instead of recomputing the input_ids. --- backends/v3/src/radix.rs | 62 ++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 40b6a399..189b8082 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -71,6 +71,8 @@ impl Allocator for RadixAllocator { let mut blocks = vec![]; let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { let node_id = self.cache_blocks.find( + // XXX This is super important: we cannot match the entire prefix, + // otherwise input_ids is empty and the shard code cannot handle that. &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)], &mut blocks, ); @@ -150,7 +152,9 @@ impl Allocator for RadixAllocator { .expect("Failed to decrement refcount"); if let Some(prefill_tokens) = allocation.prefill_tokens { - let prefill_tokens = prefill_tokens.as_slice(); + // XXX We matched everything except the last token + let prefill_tokens = + &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)]; // If there are prefill tokens that did not come from the cache, // add them to the cache. 
@@ -612,13 +616,17 @@ mod tests { #[test] fn allocator_block_size() { let mut cache = RadixAllocator::new(2, 12, None); - let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); assert_eq!(allocation.prefix_len, 4); @@ -627,13 +635,17 @@ mod tests { #[test] fn allocator_block_size_non_aligned() { let mut cache = RadixAllocator::new(2, 12, None); - let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + let allocation = cache + .allocate(7, Some(Arc::new(vec![0, 1, 2, 99]))) + .unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + let allocation = cache + .allocate(7, Some(Arc::new(vec![0, 1, 2, 99]))) + .unwrap(); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); assert_eq!(allocation.prefix_len, 2); @@ -642,13 +654,17 @@ mod tests { #[test] fn allocator_reuses_prefixes() { let mut cache = RadixAllocator::new(1, 12, None); - let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); 
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.blocks, allocation.slots); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); - let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.prefix_len, 4); } @@ -656,11 +672,13 @@ mod tests { #[test] fn allocator_collects_older_prefixes_first() { let mut cache = RadixAllocator::new(1, 7, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation1 = cache + .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation1.prefix_len, 0); - let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5, 99]))).unwrap(); assert_eq!(allocation2.blocks, vec![1, 2]); assert_eq!(allocation2.prefix_len, 0); @@ -668,7 +686,9 @@ mod tests { cache.free(allocation2.blocks.clone(), allocation2.allocation_id); // We should get the blocks of the first allocation, since they are more recent. 
- let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + let allocation3 = cache + .allocate(4, Some(Arc::new(vec![6, 7, 8, 9, 99]))) + .unwrap(); assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation3.prefix_len, 0); } @@ -676,13 +696,19 @@ mod tests { #[test] fn allocator_frees_fully_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 10, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); - let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation1 = cache + .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); + let allocation2 = cache + .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); cache.free(allocation2.blocks.clone(), allocation2.allocation_id); cache.free(allocation1.blocks.clone(), allocation1.allocation_id); - let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation3 = cache + .allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) + .unwrap(); assert_eq!(allocation3.prefix_len, 4); // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. 
@@ -692,20 +718,20 @@ mod tests { #[test] fn allocator_frees_partially_overlapping_prefills() { let mut cache = RadixAllocator::new(1, 20, None); - let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 99]))).unwrap(); assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); assert_eq!(allocation1.prefix_len, 0); cache.free(allocation1.blocks.clone(), allocation1.allocation_id); let allocation2 = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99]))) .unwrap(); assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); assert_eq!(allocation2.prefix_len, 2); let allocation3 = cache - .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99]))) .unwrap(); assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation3.prefix_len, 2); @@ -717,14 +743,14 @@ mod tests { assert_eq!(cache.free_blocks.len(), 11); let allocation4 = cache - .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99]))) .unwrap(); assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); assert_eq!(allocation4.prefix_len, 6); assert_eq!(cache.free_blocks.len(), 11); let allocation5 = cache - .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99]))) .unwrap(); assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); assert_eq!(allocation5.prefix_len, 6);