mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Fix two edge cases in RadixTrie::find
- Always return a node, not its parent.
- Do not recurse when a node does not represent a full prefix of the input.
This commit is contained in:
parent
683ff53fa3
commit
a18c13f75d
@ -283,7 +283,7 @@ impl RadixTrie {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Find worker.
|
/// Find worker.
|
||||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||||
let node = &self.nodes[node_id];
|
let node = &self.nodes[node_id];
|
||||||
|
|
||||||
if key.len() >= self.block_size {
|
if key.len() >= self.block_size {
|
||||||
@ -295,9 +295,13 @@ impl RadixTrie {
|
|||||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||||
|
|
||||||
|
// A node represents the prefix of its children. So, only
|
||||||
|
// recurse when there is a full prefix match.
|
||||||
let key = &key[shared_prefix_len..];
|
let key = &key[shared_prefix_len..];
|
||||||
if !key.is_empty() {
|
if !key.is_empty() && shared_prefix_len == child.key.len() {
|
||||||
node_id = self.find_(child_id, key, blocks);
|
return self.find_(child_id, key, blocks);
|
||||||
|
} else {
|
||||||
|
return child_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -873,4 +877,27 @@ mod tests {
|
|||||||
// Clear out the whole trie.
|
// Clear out the whole trie.
|
||||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Regression test: looking up a key that exactly matches an inserted key
/// must return the node holding that key, not its parent (the root here).
#[test]
|
||||||
|
fn full_match_returns_correct_node() {
|
||||||
|
let mut trie = RadixTrie::new(1);
|
||||||
|
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||||
|
let node_id = trie.find(&[0, 1, 2], &mut vec![]);
|
||||||
|
// At this point, there are only two nodes: the root and the node
|
||||||
|
// with tokens 0, 1, 2. Looking up the exact prefix must return
|
||||||
|
// the non-root node.
|
||||||
|
assert_ne!(node_id, trie.root);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Regression test: when the lookup key diverges partway through a node's
/// key, `find` must stop at that node instead of descending into its
/// children, because a node only represents the prefix of its children when
/// its whole key has been matched.
#[test]
|
||||||
|
fn partial_match_does_not_recurse() {
|
||||||
|
let mut trie = RadixTrie::new(1);
|
||||||
|
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||||
|
// Second insert extends the first key with 3, 4, 5.
// NOTE(review): presumably this stores 3, 4, 5 in a child of the
// [0, 1, 2] node — confirm against `insert`.
trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5])
|
||||||
|
.unwrap();
|
||||||
|
let mut blocks = Vec::new();
|
||||||
|
// The lookup key shares only 0, 1 with the stored key 0, 1, 2; its
// remaining tokens 3, 4, 5 must not be matched against any child.
let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks);
|
||||||
|
// Only the blocks of the shared prefix are collected.
// NOTE(review): assumes the argument to `RadixTrie::new` is the block
// size (one token per block) — confirm.
assert_eq!(blocks, vec![0, 1]);
|
||||||
|
// The returned node must be the same one found for the shared prefix
// alone. `blocks` is reused only as a scratch argument here; its
// contents are not checked after this call.
assert_eq!(node_id, trie.find(&[0, 1], &mut blocks))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user