From 3aa85646520dfb2fbc19383881ee90f0d2d3dd11 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 30 Apr 2025 17:33:17 +0200 Subject: [PATCH] Fixing the Trie in case of exact prefix match split. --- kvrouter/src/trie.rs | 117 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 94 insertions(+), 23 deletions(-) diff --git a/kvrouter/src/trie.rs b/kvrouter/src/trie.rs index 002778bc..8062af60 100644 --- a/kvrouter/src/trie.rs +++ b/kvrouter/src/trie.rs @@ -17,6 +17,7 @@ pub struct Trie { pub struct Node { content: Vec, nelements: usize, + local_elements: usize, children: BTreeMap, } @@ -44,14 +45,17 @@ impl Node { Self { content: vec![], nelements: 0, + local_elements: 0, children: BTreeMap::new(), } } - fn insert(&mut self, data: &[u8], left: usize) -> (usize, usize) { + fn insert(&mut self, data: &[u8]) -> (usize, usize) { let (start, stop) = if self.nelements == 0 { self.content = data.to_vec(); - (left, left + 1) + assert_eq!(self.local_elements, 0); + self.local_elements = 1; + (0, self.local_elements) } else { let mismatch = mismatch(data, &self.content); if mismatch == self.content.len() { @@ -62,54 +66,66 @@ impl Node { .iter() .take_while(|(&d, _)| d < *c) .map(|(_, n)| n.nelements) - .sum(); + .sum::() + + self.local_elements; let next_node = self.children.entry(*c).or_insert(Node::new()); - next_node.insert(&data[mismatch..], left) + let (inner_left, inner_right) = next_node.insert(&data[mismatch..]); + (inner_left + left, inner_right + left) } else { - (0, self.nelements + 1) + assert_eq!(data.len(), self.content.len()); + self.local_elements += 1; + (0, self.local_elements) }; - (left + start, left + stop) + (start, stop) } else { // Partial match, split node - let left = self.content[mismatch..].to_vec(); - let right = data[mismatch..].to_vec(); - + let left_content = self.content[mismatch..].to_vec(); + let right_content = data[mismatch..].to_vec(); let children = std::mem::take(&mut self.children); let mut children_content = vec![ - (left, children, self.nelements), - (right, BTreeMap::new(), 1), + (left_content, children, self.nelements, self.local_elements), + (right_content, BTreeMap::new(), 1, 1), ]; children_content.sort_by(|a, b| a.0.cmp(&b.0)); self.content.truncate(mismatch); self.children.clear(); - for (child_content, children, nelements) in children_content { + self.local_elements = 0; + for (child_content, children, nelements, local_elements) in children_content { if !child_content.is_empty() { let c = child_content[0]; let child = Node { content: child_content, nelements, + local_elements, children, }; self.children.insert(c, child); + } else { + self.local_elements += 1; } } - let c = data[mismatch]; - let left: usize = self - .children - .iter() - .take_while(|(&d, _)| d < c) - .map(|(_, n)| n.nelements) - .sum(); - (left, left + 1) + let (start, stop) = if let Some(c) = data.get(mismatch) { + let start = self + .children + .iter() + .take_while(|(&d, _)| d < *c) + .map(|(_, n)| n.nelements) + .sum::() + + self.local_elements; + (start, start + 1) + } else { + (0, self.local_elements) + }; + (start, stop) } }; self.nelements += 1; (start, stop) } - // TODO - #[allow(dead_code)] + #[cfg(debug_assertions)] fn remove(&mut self, data: &[u8]) -> Result<(), Error> { + // TODO reclaim the nodes too. let mismatch = mismatch(data, &self.content); if mismatch != self.content.len() { Err(Error::MissingEntry) @@ -118,6 +134,10 @@ impl Node { if let Some(node) = self.children.get_mut(c) { node.remove(&data[mismatch..])?; } + } else if self.local_elements == 0 { + return Err(Error::MissingEntry); + } else { + self.local_elements -= 1; } self.nelements -= 1; Ok(()) @@ -132,7 +152,7 @@ impl Trie { } pub fn insert(&mut self, data: &[u8]) -> (usize, usize) { - self.root.insert(data, 0) + self.root.insert(data) } // TODO @@ -158,6 +178,7 @@ mod tests { assert_eq!(trie.root.nelements, 2); assert_eq!(trie.root.content, b"t"); + assert_eq!(trie.root.local_elements, 0); assert_eq!(trie.root.children.len(), 2); assert_eq!( trie.root.children, @@ -166,6 +187,7 @@ mod tests { b'a', Node { nelements: 1, + local_elements: 1, content: b"ata".to_vec(), children: BTreeMap::new() } @@ -174,6 +196,7 @@ mod tests { b'o', Node { nelements: 1, + local_elements: 1, content: b"oto".to_vec(), children: BTreeMap::new() } @@ -183,6 +206,7 @@ mod tests { assert_eq!(trie.insert(b"coco"), (0, 1)); assert_eq!(trie.insert(b"zaza"), (3, 4)); assert_eq!(trie.root.nelements, 4); + assert_eq!(trie.root.local_elements, 0); assert_eq!(trie.root.content, b""); assert_eq!(trie.root.children.len(), 3); assert_eq!( @@ -192,6 +216,7 @@ mod tests { b'c', Node { nelements: 1, + local_elements: 1, content: b"coco".to_vec(), children: BTreeMap::new() } @@ -200,12 +225,14 @@ mod tests { b't', Node { nelements: 2, + local_elements: 0, content: b"t".to_vec(), children: BTreeMap::from_iter([ ( b'a', Node { nelements: 1, + local_elements: 1, content: b"ata".to_vec(), children: BTreeMap::new() } @@ -214,6 +241,7 @@ mod tests { b'o', Node { nelements: 1, + local_elements: 1, content: b"oto".to_vec(), children: BTreeMap::new() } @@ -225,6 +253,7 @@ mod tests { b'z', Node { nelements: 1, + local_elements: 1, content: b"zaza".to_vec(), children: BTreeMap::new() } @@ -245,6 +274,19 @@ mod tests { assert_eq!(trie.root.nelements, 1); } + #[test] + fn delete_prefix() { + let mut trie = Trie::new(); + trie.insert(b"toto"); + trie.insert(b"to"); + + assert_eq!(trie.root.nelements, 2); + assert_eq!(trie.remove(b"to"), Ok(())); + assert_eq!(trie.root.nelements, 1); + assert_eq!(trie.remove(b"toto"), Ok(())); + assert_eq!(trie.root.nelements, 0); + } + #[test] fn duplicate() { let mut trie = Trie::new(); @@ -254,4 +296,33 @@ mod tests { assert_eq!(trie.remove(b"toto"), Ok(())); assert_eq!(trie.root.nelements, 1); } + + #[test] + fn prefix() { + let mut trie = Trie::new(); + assert_eq!(trie.insert(b"toto"), (0, 1)); + assert_eq!(trie.insert(b"to"), (0, 1)); + assert_eq!(trie.root.nelements, 2); + assert_eq!(trie.insert(b"toto"), (1, 3)); + assert_eq!(trie.root.nelements, 3); + assert_eq!(trie.insert(b"tototo"), (3, 4)); + assert_eq!(trie.root.nelements, 4); + assert_eq!(trie.remove(b"toto"), Ok(())); + assert_eq!(trie.root.nelements, 3); + } + + #[test] + fn test_mismatch() { + let m = mismatch(&[0, 1, 2], &[0, 1, 3]); + assert_eq!(m, 2); + let a = vec![0; 256]; + let mut b = vec![0; 256]; + assert_eq!(mismatch(&a, &b), 256); + b[130] = 1; + assert_eq!(mismatch(&a, &b), 130); + b[129] = 1; + assert_eq!(mismatch(&a, &b), 129); + b[128] = 1; + assert_eq!(mismatch(&a, &b), 128); + } }