Fixing the Trie in case of exact prefix match split.

This commit is contained in:
Nicolas Patry 2025-04-30 17:33:17 +02:00
parent 50c8ebdef0
commit 3aa8564652
No known key found for this signature in database
GPG Key ID: 87B37D879D09DEB4

View File

@ -17,6 +17,7 @@ pub struct Trie {
pub struct Node {
content: Vec<u8>,
nelements: usize,
local_elements: usize,
children: BTreeMap<u8, Node>,
}
@ -44,14 +45,17 @@ impl Node {
Self {
content: vec![],
nelements: 0,
local_elements: 0,
children: BTreeMap::new(),
}
}
fn insert(&mut self, data: &[u8], left: usize) -> (usize, usize) {
fn insert(&mut self, data: &[u8]) -> (usize, usize) {
let (start, stop) = if self.nelements == 0 {
self.content = data.to_vec();
(left, left + 1)
assert_eq!(self.local_elements, 0);
self.local_elements = 1;
(0, self.local_elements)
} else {
let mismatch = mismatch(data, &self.content);
if mismatch == self.content.len() {
@ -62,54 +66,66 @@ impl Node {
.iter()
.take_while(|(&d, _)| d < *c)
.map(|(_, n)| n.nelements)
.sum();
.sum::<usize>()
+ self.local_elements;
let next_node = self.children.entry(*c).or_insert(Node::new());
next_node.insert(&data[mismatch..], left)
let (inner_left, inner_right) = next_node.insert(&data[mismatch..]);
(inner_left + left, inner_right + left)
} else {
(0, self.nelements + 1)
assert_eq!(data.len(), self.content.len());
self.local_elements += 1;
(0, self.local_elements)
};
(left + start, left + stop)
(start, stop)
} else {
// Partial match, split node
let left = self.content[mismatch..].to_vec();
let right = data[mismatch..].to_vec();
let left_content = self.content[mismatch..].to_vec();
let right_content = data[mismatch..].to_vec();
let children = std::mem::take(&mut self.children);
let mut children_content = vec![
(left, children, self.nelements),
(right, BTreeMap::new(), 1),
(left_content, children, self.nelements, self.local_elements),
(right_content, BTreeMap::new(), 1, 1),
];
children_content.sort_by(|a, b| a.0.cmp(&b.0));
self.content.truncate(mismatch);
self.children.clear();
for (child_content, children, nelements) in children_content {
self.local_elements = 0;
for (child_content, children, nelements, local_elements) in children_content {
if !child_content.is_empty() {
let c = child_content[0];
let child = Node {
content: child_content,
nelements,
local_elements,
children,
};
self.children.insert(c, child);
} else {
self.local_elements += 1;
}
}
let c = data[mismatch];
let left: usize = self
.children
.iter()
.take_while(|(&d, _)| d < c)
.map(|(_, n)| n.nelements)
.sum();
(left, left + 1)
let (start, stop) = if let Some(c) = data.get(mismatch) {
let start = self
.children
.iter()
.take_while(|(&d, _)| d < *c)
.map(|(_, n)| n.nelements)
.sum::<usize>()
+ self.local_elements;
(start, start + 1)
} else {
(0, self.local_elements)
};
(start, stop)
}
};
self.nelements += 1;
(start, stop)
}
// TODO
#[allow(dead_code)]
#[cfg(debug_assertions)]
fn remove(&mut self, data: &[u8]) -> Result<(), Error> {
// TODO reclaim the nodes too.
let mismatch = mismatch(data, &self.content);
if mismatch != self.content.len() {
Err(Error::MissingEntry)
@ -118,6 +134,10 @@ impl Node {
if let Some(node) = self.children.get_mut(c) {
node.remove(&data[mismatch..])?;
}
} else if self.local_elements == 0 {
return Err(Error::MissingEntry);
} else {
self.local_elements -= 1;
}
self.nelements -= 1;
Ok(())
@ -132,7 +152,7 @@ impl Trie {
}
pub fn insert(&mut self, data: &[u8]) -> (usize, usize) {
self.root.insert(data, 0)
self.root.insert(data)
}
// TODO
@ -158,6 +178,7 @@ mod tests {
assert_eq!(trie.root.nelements, 2);
assert_eq!(trie.root.content, b"t");
assert_eq!(trie.root.local_elements, 0);
assert_eq!(trie.root.children.len(), 2);
assert_eq!(
trie.root.children,
@ -166,6 +187,7 @@ mod tests {
b'a',
Node {
nelements: 1,
local_elements: 1,
content: b"ata".to_vec(),
children: BTreeMap::new()
}
@ -174,6 +196,7 @@ mod tests {
b'o',
Node {
nelements: 1,
local_elements: 1,
content: b"oto".to_vec(),
children: BTreeMap::new()
}
@ -183,6 +206,7 @@ mod tests {
assert_eq!(trie.insert(b"coco"), (0, 1));
assert_eq!(trie.insert(b"zaza"), (3, 4));
assert_eq!(trie.root.nelements, 4);
assert_eq!(trie.root.local_elements, 0);
assert_eq!(trie.root.content, b"");
assert_eq!(trie.root.children.len(), 3);
assert_eq!(
@ -192,6 +216,7 @@ mod tests {
b'c',
Node {
nelements: 1,
local_elements: 1,
content: b"coco".to_vec(),
children: BTreeMap::new()
}
@ -200,12 +225,14 @@ mod tests {
b't',
Node {
nelements: 2,
local_elements: 0,
content: b"t".to_vec(),
children: BTreeMap::from_iter([
(
b'a',
Node {
nelements: 1,
local_elements: 1,
content: b"ata".to_vec(),
children: BTreeMap::new()
}
@ -214,6 +241,7 @@ mod tests {
b'o',
Node {
nelements: 1,
local_elements: 1,
content: b"oto".to_vec(),
children: BTreeMap::new()
}
@ -225,6 +253,7 @@ mod tests {
b'z',
Node {
nelements: 1,
local_elements: 1,
content: b"zaza".to_vec(),
children: BTreeMap::new()
}
@ -245,6 +274,19 @@ mod tests {
assert_eq!(trie.root.nelements, 1);
}
// Regression test: a key that is a strict prefix of another stored key can be
// removed on its own, and the longer key can still be removed afterwards.
#[test]
fn delete_prefix() {
let mut trie = Trie::new();
// Insert the longer key first, then its strict prefix "to".
trie.insert(b"toto");
trie.insert(b"to");
// The root's counter reflects both entries.
assert_eq!(trie.root.nelements, 2);
// Removing the shorter key succeeds and lowers the root count by exactly one.
assert_eq!(trie.remove(b"to"), Ok(()));
assert_eq!(trie.root.nelements, 1);
// The longer key is still removable, which brings the root count to zero.
assert_eq!(trie.remove(b"toto"), Ok(()));
assert_eq!(trie.root.nelements, 0);
}
#[test]
fn duplicate() {
let mut trie = Trie::new();
@ -254,4 +296,33 @@ mod tests {
assert_eq!(trie.remove(b"toto"), Ok(()));
assert_eq!(trie.root.nelements, 1);
}
// Checks the pairs returned by `insert` when keys are prefixes or duplicates of
// one another, and that the root counter tracks every insert and remove.
// NOTE(review): the returned pair looks like a half-open range of positions in
// sorted key order covering the entries equal to the inserted key; confirm
// against the documented contract of `Trie::insert`.
#[test]
fn prefix() {
let mut trie = Trie::new();
// First key in an empty trie.
assert_eq!(trie.insert(b"toto"), (0, 1));
// "to" is a strict prefix of "toto"; the expected pair starts at 0.
assert_eq!(trie.insert(b"to"), (0, 1));
assert_eq!(trie.root.nelements, 2);
// Second copy of "toto": the expected pair starts at 1 and spans two slots.
assert_eq!(trie.insert(b"toto"), (1, 3));
assert_eq!(trie.root.nelements, 3);
// "tototo" extends "toto"; the expected pair starts after the three earlier entries.
assert_eq!(trie.insert(b"tototo"), (3, 4));
assert_eq!(trie.root.nelements, 4);
// Removing one "toto" lowers the root count by exactly one.
assert_eq!(trie.remove(b"toto"), Ok(()));
assert_eq!(trie.root.nelements, 3);
}
// Pins down `mismatch`: it is expected to return the index of the first
// position at which the two slices differ, or the length when they are equal.
#[test]
fn test_mismatch() {
// Slices agree on indices 0 and 1 and differ at index 2.
let m = mismatch(&[0, 1, 2], &[0, 1, 3]);
assert_eq!(m, 2);
let a = vec![0; 256];
let mut b = vec![0; 256];
// Two identical 256-byte buffers: the expected result is the full length.
assert_eq!(mismatch(&a, &b), 256);
// Introduce differences at successively earlier indices; the earliest one
// must be reported each time, even though the later ones are still present.
// NOTE(review): indices 128..=130 in a 256-byte buffer presumably straddle an
// internal chunk boundary of `mismatch` (not visible here); confirm.
b[130] = 1;
assert_eq!(mismatch(&a, &b), 130);
b[129] = 1;
assert_eq!(mismatch(&a, &b), 129);
b[128] = 1;
assert_eq!(mismatch(&a, &b), 128);
}
}