diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 532ec6dd..81ce61d1 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -283,7 +283,7 @@ impl RadixTrie { } /// Find worker. - fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { let node = &self.nodes[node_id]; if key.len() >= self.block_size { @@ -295,9 +295,13 @@ impl RadixTrie { assert_eq!(shared_prefix_len % self.block_size, 0); blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); + // A node represents the prefix of its children. So, only + // recurse when there is a full prefix match. let key = &key[shared_prefix_len..]; - if !key.is_empty() { - node_id = self.find_(child_id, key, blocks); + if !key.is_empty() && shared_prefix_len == child.key.len() { + return self.find_(child_id, key, blocks); + } else { + return child_id; } } } @@ -873,4 +877,27 @@ mod tests { // Clear out the whole trie. assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); } + + #[test] + fn full_match_returns_correct_node() { + let mut trie = RadixTrie::new(1); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + let node_id = trie.find(&[0, 1, 2], &mut vec![]); + // At this point, there are only two nodes: the root and the node + // with tokens 0, 1, 2. Looking up the exact prefix must return + // the non-root node. + assert_ne!(node_id, trie.root); + } + + #[test] + fn partial_match_does_not_recurse() { + let mut trie = RadixTrie::new(1); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5]) + .unwrap(); + let mut blocks = Vec::new(); + let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + assert_eq!(node_id, trie.find(&[0, 1], &mut blocks)) + } }