mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Fix splitting
This commit is contained in:
parent
5da696046e
commit
07ede8d8e5
@ -148,8 +148,8 @@ impl RadixTrie {
|
|||||||
|
|
||||||
// The node's prefix and the insertion prefix only match partially,
|
// The node's prefix and the insertion prefix only match partially,
|
||||||
// split the node to just contain the matching part. Then insert the
|
// split the node to just contain the matching part. Then insert the
|
||||||
// remainder of the prefix into the node again.
|
// remainder of the prefix into the node again
|
||||||
self.split(child_id, shared_prefix_len);
|
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||||
let key = &key[shared_prefix_len..];
|
let key = &key[shared_prefix_len..];
|
||||||
let blocks = &blocks[shared_prefix_len..];
|
let blocks = &blocks[shared_prefix_len..];
|
||||||
self.insert_(child_id, key, blocks)
|
self.insert_(child_id, key, blocks)
|
||||||
@ -159,11 +159,28 @@ impl RadixTrie {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split(&mut self, node_id: NodeId, prefix_len: usize) {
|
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
|
||||||
|
// We have to make the current node a child to ensure that its
|
||||||
|
// properties and node id stay the same.
|
||||||
|
|
||||||
let node = self.nodes.get_mut(node_id).unwrap();
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
let rest_key = node.key.split_off(prefix_len);
|
let mut parent_key = node.key.split_off(prefix_len);
|
||||||
let rest_blocks = node.blocks.split_off(prefix_len);
|
let mut parent_blocks = node.blocks.split_off(prefix_len);
|
||||||
self.add_node(node_id, rest_key, rest_blocks);
|
|
||||||
|
// Move first part of the prefix to the parent.
|
||||||
|
std::mem::swap(&mut node.key, &mut parent_key);
|
||||||
|
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
||||||
|
|
||||||
|
let node_key = node.key[0];
|
||||||
|
|
||||||
|
let grandparent_id = node.parent.unwrap();
|
||||||
|
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
||||||
|
self.add_node_to_parent(parent_id, node_key, node_id);
|
||||||
|
|
||||||
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
node.parent = Some(parent_id);
|
||||||
|
|
||||||
|
parent_id
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_node(
|
fn add_node(
|
||||||
@ -171,17 +188,26 @@ impl RadixTrie {
|
|||||||
parent_id: NodeId,
|
parent_id: NodeId,
|
||||||
key: impl Into<Vec<u32>>,
|
key: impl Into<Vec<u32>>,
|
||||||
blocks: impl Into<Vec<u32>>,
|
blocks: impl Into<Vec<u32>>,
|
||||||
) {
|
) -> NodeId {
|
||||||
let key = key.into();
|
let key = key.into();
|
||||||
let blocks = blocks.into();
|
let blocks = blocks.into();
|
||||||
let first = key[0];
|
let first = key[0];
|
||||||
|
|
||||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||||
let child_id = self.nodes.insert(child);
|
let child_id = self.nodes.insert(child);
|
||||||
let node = self.nodes.get_mut(parent_id).unwrap();
|
|
||||||
node.children.insert(first, child_id);
|
self.add_node_to_parent(parent_id, first, child_id);
|
||||||
self.incref(parent_id);
|
|
||||||
self.leaves.insert((self.time, child_id));
|
self.leaves.insert((self.time, child_id));
|
||||||
|
|
||||||
|
child_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
|
||||||
|
let parent = self.nodes.get_mut(parent_id).unwrap();
|
||||||
|
if !parent.children.insert(first, child_id).is_some() {
|
||||||
|
// Only increase reference count if child does not replace another child.
|
||||||
|
self.incref(parent_id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
||||||
@ -212,13 +238,15 @@ impl RadixTrie {
|
|||||||
fn print_debug_(&self, node_id: NodeId, indent: usize) {
|
fn print_debug_(&self, node_id: NodeId, indent: usize) {
|
||||||
let node = &self.nodes[node_id];
|
let node = &self.nodes[node_id];
|
||||||
eprintln!(
|
eprintln!(
|
||||||
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}",
|
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
|
||||||
" ".repeat(indent),
|
" ".repeat(indent),
|
||||||
node_id,
|
node_id,
|
||||||
node.key,
|
node.key,
|
||||||
node.blocks,
|
node.blocks,
|
||||||
node.ref_count,
|
node.ref_count,
|
||||||
node.last_accessed
|
node.last_accessed,
|
||||||
|
node.parent,
|
||||||
|
node.children
|
||||||
);
|
);
|
||||||
for child_id in self.nodes[node_id].children.values() {
|
for child_id in self.nodes[node_id].children.values() {
|
||||||
self.print_debug_(*child_id, indent + 2);
|
self.print_debug_(*child_id, indent + 2);
|
||||||
|
Loading…
Reference in New Issue
Block a user