mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
Fix two edge cases in RadixTrie::find
(#3067)
- Always return a node, not its parent.
- Do not recurse when a node does not represent a full prefix of the input.
This commit is contained in:
parent
a914a21899
commit
fa4e9511f8
@ -283,7 +283,7 @@ impl RadixTrie {
|
||||
}
|
||||
|
||||
/// Find worker.
|
||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
let node = &self.nodes[node_id];
|
||||
|
||||
if key.len() >= self.block_size {
|
||||
@ -295,9 +295,13 @@ impl RadixTrie {
|
||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||
|
||||
// A node represents the prefix of its children. So, only
|
||||
// recurse when there is a full prefix match.
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
node_id = self.find_(child_id, key, blocks);
|
||||
if !key.is_empty() && shared_prefix_len == child.key.len() {
|
||||
return self.find_(child_id, key, blocks);
|
||||
} else {
|
||||
return child_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -873,4 +877,27 @@ mod tests {
|
||||
// Clear out the whole trie.
|
||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||
}
|
||||
|
||||
/// Regression test for the "always return a node, not its parent" fix:
/// inserts a single key `[0, 1, 2]` and looks up exactly that key.
/// The returned node id must differ from `trie.root`.
// NOTE(review): `RadixTrie::new(1)` presumably sets the block size to 1
// (one block per token) — the constructor is not visible here; confirm.
#[test]
|
||||
fn full_match_returns_correct_node() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
let node_id = trie.find(&[0, 1, 2], &mut vec![]);
|
||||
// At this point, there are only two nodes: the root and the node
|
||||
// with tokens 0, 1, 2. Looking up the exact prefix must return
|
||||
|
||||
// the non-root node.
|
||||
assert_ne!(node_id, trie.root);
|
||||
}
|
||||
|
||||
/// Regression test for the "do not recurse on a partial prefix match" fix.
///
/// Two keys are inserted: `[0, 1, 2]` and `[0, 1, 2, 3, 4, 5]`. The lookup
/// key `[0, 1, 3, 4, 5]` agrees with them only on its first two tokens, so
/// the lookup is expected to collect exactly the blocks `[0, 1]` and to
/// return the same node id as a lookup of `[0, 1]`.
// NOTE(review): `RadixTrie::new(1)` presumably sets the block size to 1
// (one block per token) — the constructor is not visible here; confirm.
#[test]
|
||||
fn partial_match_does_not_recurse() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5])
|
||||
.unwrap();
|
||||
let mut blocks = Vec::new();
|
||||
let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1]);
|
||||
// The second lookup reuses `blocks` as its output buffer; only the
// returned node id is compared, the buffer is not inspected again.
assert_eq!(node_id, trie.find(&[0, 1], &mut blocks))
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user