From f5182c188cdc0a7c38f057eacb79cc8fbc382985 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Aug 2024 17:43:27 +0200 Subject: [PATCH] Is this enough to make it work ? --- backends/v3/src/queue.rs | 6 +- backends/v3/src/radix.rs | 130 +++++++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 41 deletions(-) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 4958b2d4..f066318a 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -291,7 +291,11 @@ impl State { None } Some(block_allocator) => { - prefill_tokens += entry.request.input_length; + if entry.request.input_length <= prefill_token_budget { + prefill_tokens += entry.request.input_length; + } else { + prefill_tokens = prefill_token_budget; + } let max_new_tokens = match self.window_size { None => entry.request.stopping_parameters.max_new_tokens, Some(window_size) => min( diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 2606376b..be24b67b 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -19,15 +19,17 @@ pub struct RadixAllocator { // This isn't used because the prefix need to match without the windowing // mecanism. This at worst is overallocating, not necessarily being wrong. window_size: Option, + + block_size: u32, } impl RadixAllocator { pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { - assert_eq!( - block_size, 1, - "Radix tree allocator only works with block_size=1, was: {}", - block_size - ); + // assert_eq!( + // block_size, 1, + // "Radix tree allocator only works with block_size=1, was: {}", + // block_size + // ); // if window_size.is_some() { // unimplemented!("Window size not supported in the prefix-caching block allocator yet"); // } @@ -35,11 +37,12 @@ impl RadixAllocator { RadixAllocator { allocation_id: 0, allocations: HashMap::new(), - cache_blocks: RadixTrie::new(), + cache_blocks: RadixTrie::new(block_size as usize), // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), window_size, + block_size, } } @@ -91,10 +94,10 @@ impl Allocator for RadixAllocator { .incref(prefix_node) .expect("Failed to increment refcount"); - let prefix_len = blocks.len(); + let prefix_len = blocks.len() * self.block_size as usize; let suffix_len = tokens - prefix_len as u32; - let suffix_blocks = suffix_len; + let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), @@ -107,7 +110,20 @@ impl Allocator for RadixAllocator { } // 1:1 mapping of blocks and slots. - let slots = blocks.clone(); + let slots = if self.block_size == 1 { + blocks.clone() + } else { + let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); + 'slots: for block_id in &blocks { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() as u32 == tokens { + break 'slots; + } + } + } + slots + }; let allocation = RadixAllocation { prefix_node, @@ -142,12 +158,16 @@ impl Allocator for RadixAllocator { if let Some(prefill_tokens) = allocation.prefill_tokens { let prefill_tokens = prefill_tokens.as_slice(); + assert_eq!(prefill_tokens.len() % self.block_size as usize, 0); // If there are prefill tokens that did not come from the cache, // add them to the cache. if prefill_tokens.len() > allocation.cached_prefix_len { let prefix_len = self .cache_blocks - .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) + .insert( + prefill_tokens, + &blocks[..prefill_tokens.len() / self.block_size as usize], + ) // Unwrap, failing is a programming error. .expect("Failed to store prefill tokens"); @@ -213,17 +233,14 @@ pub struct RadixTrie { /// Time as a monotonically increating counter to avoid the system /// call that a real time lookup would require. time: u64, -} -impl Default for RadixTrie { - fn default() -> Self { - Self::new() - } + /// All blocks need to be aligned with this + block_size: usize, } impl RadixTrie { /// Construct a new radix trie. - pub fn new() -> Self { + pub fn new(block_size: usize) -> Self { let root = TrieNode::new(vec![], vec![], 0, None); let mut nodes = SlotMap::new(); let root = nodes.insert(root); @@ -232,13 +249,14 @@ impl RadixTrie { nodes, root, time: 0, + block_size, } } /// Find the prefix of the given tokens. /// /// The blocks corresponding to the part of the prefix that could be found - /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. /// Returns the identifier of the trie node that contains the longest /// prefix. The node identifier can be used by callers to e.g. increase its /// reference count. @@ -256,8 +274,9 @@ impl RadixTrie { if let Some(&child_id) = node.children.get(&key[0]) { self.update_access_time(child_id); let child = self.nodes.get(child_id).expect("Invalid child identifier"); - let shared_prefix_len = child.key.shared_prefix_len(key); - blocks.extend(&child.blocks[..shared_prefix_len]); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); let key = &key[shared_prefix_len..]; if !key.is_empty() { @@ -358,7 +377,8 @@ impl RadixTrie { /// the first 10 elements of the tree **the blocks are not updated**. pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { self.time += 1; - self.insert_(self.root, tokens, blocks) + let common = self.insert_(self.root, tokens, blocks)?; + Ok(common) } /// Insertion worker. @@ -372,7 +392,7 @@ impl RadixTrie { // the part of the prefix that is already in the trie to detect // mismatches. - if tokens.len() != blocks.len() { + if tokens.len() != blocks.len() * self.block_size { return Err(TrieError::BlockTokenCountMismatch); } @@ -383,10 +403,10 @@ impl RadixTrie { .get_mut(child_id) // Unwrap here, since failure is a bug. .expect("Child node does not exist"); - let shared_prefix_len = child.key.shared_prefix_len(tokens); + let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); // We are done, the prefix is already in the trie. - if shared_prefix_len == tokens.len() { + if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { return Ok(shared_prefix_len); } @@ -396,7 +416,7 @@ impl RadixTrie { + self.insert_( child_id, &tokens[shared_prefix_len..], - &blocks[shared_prefix_len..], + &blocks[shared_prefix_len / self.block_size..], )?); } @@ -405,7 +425,7 @@ impl RadixTrie { // remainder of the prefix into the node again let child_id = self.split_node(child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; - let blocks = &blocks[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len / self.block_size..]; Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) } else { self.add_node(node_id, tokens, blocks); @@ -559,18 +579,9 @@ impl TrieNode { } } -/// Helper trait to get the length of the shared prefix of two sequences. -trait SharedPrefixLen { - fn shared_prefix_len(&self, other: &Self) -> usize; -} - -impl SharedPrefixLen for [T] -where - T: PartialEq, -{ - fn shared_prefix_len(&self, other: &Self) -> usize { - self.iter().zip(other).take_while(|(a, b)| a == b).count() - } +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + (full / block_size) * block_size } #[cfg(test)] @@ -579,6 +590,21 @@ mod tests { use super::*; + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 6, 7]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]); + assert_eq!(allocation.prefix_len, 4); + } + #[test] fn allocator_reuses_prefixes() { let mut cache = RadixAllocator::new(1, 12, None); @@ -673,7 +699,7 @@ mod tests { #[test] fn trie_insertions_have_correct_prefix_len() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); @@ -694,9 +720,33 @@ mod tests { ); } + #[test] + fn trie_insertions_block_size() { + let mut trie = RadixTrie::new(2); + + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0); + + // Already exists. + // But needs to be block_size aligned + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3]) + .unwrap(), + 2 + ); + } + #[test] fn trie_get_returns_correct_blocks() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); @@ -730,7 +780,7 @@ mod tests { #[test] fn trie_evict_removes_correct_blocks() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) .unwrap();