mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Update parent refcounts when inserting a child
This commit is contained in:
parent
aa1c96a7a4
commit
ce76f4ccc3
@ -1,4 +1,4 @@
|
|||||||
use std::collections::{hash_map::Entry, HashMap};
|
use std::collections::{BTreeSet, HashMap};
|
||||||
|
|
||||||
use slotmap::{DefaultKey, SlotMap};
|
use slotmap::{DefaultKey, SlotMap};
|
||||||
|
|
||||||
@ -18,6 +18,7 @@ type NodeId = DefaultKey;
|
|||||||
|
|
||||||
pub struct RadixTrie {
|
pub struct RadixTrie {
|
||||||
root: DefaultKey,
|
root: DefaultKey,
|
||||||
|
leaves: BTreeSet<(u64, NodeId)>,
|
||||||
nodes: SlotMap<NodeId, TrieNode>,
|
nodes: SlotMap<NodeId, TrieNode>,
|
||||||
time: u64,
|
time: u64,
|
||||||
}
|
}
|
||||||
@ -28,18 +29,19 @@ impl RadixTrie {
|
|||||||
let mut nodes = SlotMap::new();
|
let mut nodes = SlotMap::new();
|
||||||
let root = nodes.insert(root);
|
let root = nodes.insert(root);
|
||||||
RadixTrie {
|
RadixTrie {
|
||||||
|
leaves: BTreeSet::new(),
|
||||||
nodes,
|
nodes,
|
||||||
root,
|
root,
|
||||||
time: 0,
|
time: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) {
|
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||||
self.time += 1;
|
self.time += 1;
|
||||||
self.find_(self.root, key, blocks);
|
self.find_(self.root, key, blocks)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) {
|
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||||
let node = &self.nodes[node_id];
|
let node = &self.nodes[node_id];
|
||||||
|
|
||||||
if let Some(&child_id) = node.children.get(&key[0]) {
|
if let Some(&child_id) = node.children.get(&key[0]) {
|
||||||
@ -50,9 +52,28 @@ impl RadixTrie {
|
|||||||
|
|
||||||
let key = &key[shared_prefix_len..];
|
let key = &key[shared_prefix_len..];
|
||||||
if !key.is_empty() {
|
if !key.is_empty() {
|
||||||
self.find_(child_id, key, blocks);
|
node_id = self.find_(child_id, key, blocks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decref(&mut self, node_id: NodeId) {
|
||||||
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
assert!(node.ref_count > 0);
|
||||||
|
node.ref_count -= 1;
|
||||||
|
if node.ref_count == 0 {
|
||||||
|
self.leaves.insert((node.last_accessed, node_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn incref(&mut self, node_id: NodeId) {
|
||||||
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
if node.ref_count == 0 {
|
||||||
|
self.leaves.remove(&(node.last_accessed, node_id));
|
||||||
|
}
|
||||||
|
node.ref_count += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
|
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
|
||||||
@ -91,30 +112,33 @@ impl RadixTrie {
|
|||||||
let blocks = &blocks[shared_prefix_len..];
|
let blocks = &blocks[shared_prefix_len..];
|
||||||
self.insert_(child_id, key, blocks)
|
self.insert_(child_id, key, blocks)
|
||||||
} else {
|
} else {
|
||||||
let child = TrieNode::new(key.to_vec(), blocks.to_vec(), self.time, Some(node_id));
|
self.add_child(node_id, key, blocks);
|
||||||
let child_id = self.nodes.insert(child);
|
|
||||||
let node = self.nodes.get_mut(node_id).unwrap();
|
|
||||||
node.children.insert(key[0], child_id);
|
|
||||||
key.len()
|
key.len()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split(&mut self, node_id: NodeId, prefix_len: usize) {
|
fn split(&mut self, node_id: NodeId, prefix_len: usize) {
|
||||||
let node = self.nodes.get_mut(node_id).unwrap();
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
|
||||||
let rest_key = node.key.split_off(prefix_len);
|
let rest_key = node.key.split_off(prefix_len);
|
||||||
let rest_blocks = node.blocks.split_off(prefix_len);
|
let rest_blocks = node.blocks.split_off(prefix_len);
|
||||||
let first = rest_key[0];
|
self.add_child(node_id, rest_key, rest_blocks);
|
||||||
|
}
|
||||||
|
|
||||||
let new_id = self.nodes.insert(TrieNode::new(
|
fn add_child(
|
||||||
rest_key,
|
&mut self,
|
||||||
rest_blocks,
|
parent_id: NodeId,
|
||||||
self.time,
|
key: impl Into<Vec<u32>>,
|
||||||
Some(node_id),
|
blocks: impl Into<Vec<u32>>,
|
||||||
));
|
) {
|
||||||
|
let key = key.into();
|
||||||
|
let blocks = blocks.into();
|
||||||
|
let first = key[0];
|
||||||
|
|
||||||
let node = self.nodes.get_mut(node_id).unwrap();
|
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||||
node.children.insert(first, new_id);
|
let child_id = self.nodes.insert(child);
|
||||||
|
let node = self.nodes.get_mut(parent_id).unwrap();
|
||||||
|
node.children.insert(first, child_id);
|
||||||
|
self.incref(parent_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,6 +148,7 @@ struct TrieNode {
|
|||||||
key: Vec<u32>,
|
key: Vec<u32>,
|
||||||
last_accessed: u64,
|
last_accessed: u64,
|
||||||
parent: Option<NodeId>,
|
parent: Option<NodeId>,
|
||||||
|
ref_count: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrieNode {
|
impl TrieNode {
|
||||||
@ -134,6 +159,7 @@ impl TrieNode {
|
|||||||
blocks,
|
blocks,
|
||||||
last_accessed,
|
last_accessed,
|
||||||
parent,
|
parent,
|
||||||
|
ref_count: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user