Fix two edge cases in RadixTrie::find

- Always return a node, not its parent.
- Do not recurse when a node does not represent a full prefix of the
  input.
This commit is contained in:
Daniël de Kok 2025-03-03 12:59:26 +00:00
parent 683ff53fa3
commit a18c13f75d

View File

@ -283,7 +283,7 @@ impl RadixTrie {
} }
/// Find worker. /// Find worker.
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId { fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id]; let node = &self.nodes[node_id];
if key.len() >= self.block_size { if key.len() >= self.block_size {
@ -295,9 +295,13 @@ impl RadixTrie {
assert_eq!(shared_prefix_len % self.block_size, 0); assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
// A node represents the prefix of its children. So, only
// recurse when there is a full prefix match.
let key = &key[shared_prefix_len..]; let key = &key[shared_prefix_len..];
if !key.is_empty() { if !key.is_empty() && shared_prefix_len == child.key.len() {
node_id = self.find_(child_id, key, blocks); return self.find_(child_id, key, blocks);
} else {
return child_id;
} }
} }
} }
@ -873,4 +877,27 @@ mod tests {
// Clear out the whole trie. // Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
} }
#[test]
fn full_match_returns_correct_node() {
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
let node_id = trie.find(&[0, 1, 2], &mut vec![]);
// At this point, there are only two nodes: the root and the node
// with tokens 0, 1, 2. Looking up the exact prefix must return
// the non-root node.
assert_ne!(node_id, trie.root);
}
#[test]
fn partial_match_does_not_recurse() {
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5])
.unwrap();
let mut blocks = Vec::new();
let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks);
assert_eq!(blocks, vec![0, 1]);
assert_eq!(node_id, trie.find(&[0, 1], &mut blocks))
}
} }