mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-03 05:02:12 +00:00
Fixing the Trie in case of exact prefix match split.
This commit is contained in:
parent
50c8ebdef0
commit
3aa8564652
@ -17,6 +17,7 @@ pub struct Trie {
|
||||
pub struct Node {
|
||||
content: Vec<u8>,
|
||||
nelements: usize,
|
||||
local_elements: usize,
|
||||
children: BTreeMap<u8, Node>,
|
||||
}
|
||||
|
||||
@ -44,14 +45,17 @@ impl Node {
|
||||
Self {
|
||||
content: vec![],
|
||||
nelements: 0,
|
||||
local_elements: 0,
|
||||
children: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(&mut self, data: &[u8], left: usize) -> (usize, usize) {
|
||||
fn insert(&mut self, data: &[u8]) -> (usize, usize) {
|
||||
let (start, stop) = if self.nelements == 0 {
|
||||
self.content = data.to_vec();
|
||||
(left, left + 1)
|
||||
assert_eq!(self.local_elements, 0);
|
||||
self.local_elements = 1;
|
||||
(0, self.local_elements)
|
||||
} else {
|
||||
let mismatch = mismatch(data, &self.content);
|
||||
if mismatch == self.content.len() {
|
||||
@ -62,54 +66,66 @@ impl Node {
|
||||
.iter()
|
||||
.take_while(|(&d, _)| d < *c)
|
||||
.map(|(_, n)| n.nelements)
|
||||
.sum();
|
||||
.sum::<usize>()
|
||||
+ self.local_elements;
|
||||
let next_node = self.children.entry(*c).or_insert(Node::new());
|
||||
next_node.insert(&data[mismatch..], left)
|
||||
let (inner_left, inner_right) = next_node.insert(&data[mismatch..]);
|
||||
(inner_left + left, inner_right + left)
|
||||
} else {
|
||||
(0, self.nelements + 1)
|
||||
assert_eq!(data.len(), self.content.len());
|
||||
self.local_elements += 1;
|
||||
(0, self.local_elements)
|
||||
};
|
||||
(left + start, left + stop)
|
||||
(start, stop)
|
||||
} else {
|
||||
// Partial match, split node
|
||||
let left = self.content[mismatch..].to_vec();
|
||||
let right = data[mismatch..].to_vec();
|
||||
|
||||
let left_content = self.content[mismatch..].to_vec();
|
||||
let right_content = data[mismatch..].to_vec();
|
||||
let children = std::mem::take(&mut self.children);
|
||||
let mut children_content = vec![
|
||||
(left, children, self.nelements),
|
||||
(right, BTreeMap::new(), 1),
|
||||
(left_content, children, self.nelements, self.local_elements),
|
||||
(right_content, BTreeMap::new(), 1, 1),
|
||||
];
|
||||
children_content.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
self.content.truncate(mismatch);
|
||||
self.children.clear();
|
||||
for (child_content, children, nelements) in children_content {
|
||||
self.local_elements = 0;
|
||||
for (child_content, children, nelements, local_elements) in children_content {
|
||||
if !child_content.is_empty() {
|
||||
let c = child_content[0];
|
||||
let child = Node {
|
||||
content: child_content,
|
||||
nelements,
|
||||
local_elements,
|
||||
children,
|
||||
};
|
||||
self.children.insert(c, child);
|
||||
} else {
|
||||
self.local_elements += 1;
|
||||
}
|
||||
}
|
||||
let c = data[mismatch];
|
||||
let left: usize = self
|
||||
.children
|
||||
.iter()
|
||||
.take_while(|(&d, _)| d < c)
|
||||
.map(|(_, n)| n.nelements)
|
||||
.sum();
|
||||
(left, left + 1)
|
||||
let (start, stop) = if let Some(c) = data.get(mismatch) {
|
||||
let start = self
|
||||
.children
|
||||
.iter()
|
||||
.take_while(|(&d, _)| d < *c)
|
||||
.map(|(_, n)| n.nelements)
|
||||
.sum::<usize>()
|
||||
+ self.local_elements;
|
||||
(start, start + 1)
|
||||
} else {
|
||||
(0, self.local_elements)
|
||||
};
|
||||
(start, stop)
|
||||
}
|
||||
};
|
||||
self.nelements += 1;
|
||||
(start, stop)
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
#[cfg(debug_assertions)]
|
||||
fn remove(&mut self, data: &[u8]) -> Result<(), Error> {
|
||||
// TODO reclaim the nodes too.
|
||||
let mismatch = mismatch(data, &self.content);
|
||||
if mismatch != self.content.len() {
|
||||
Err(Error::MissingEntry)
|
||||
@ -118,6 +134,10 @@ impl Node {
|
||||
if let Some(node) = self.children.get_mut(c) {
|
||||
node.remove(&data[mismatch..])?;
|
||||
}
|
||||
} else if self.local_elements == 0 {
|
||||
return Err(Error::MissingEntry);
|
||||
} else {
|
||||
self.local_elements -= 1;
|
||||
}
|
||||
self.nelements -= 1;
|
||||
Ok(())
|
||||
@ -132,7 +152,7 @@ impl Trie {
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, data: &[u8]) -> (usize, usize) {
|
||||
self.root.insert(data, 0)
|
||||
self.root.insert(data)
|
||||
}
|
||||
|
||||
// TODO
|
||||
@ -158,6 +178,7 @@ mod tests {
|
||||
|
||||
assert_eq!(trie.root.nelements, 2);
|
||||
assert_eq!(trie.root.content, b"t");
|
||||
assert_eq!(trie.root.local_elements, 0);
|
||||
assert_eq!(trie.root.children.len(), 2);
|
||||
assert_eq!(
|
||||
trie.root.children,
|
||||
@ -166,6 +187,7 @@ mod tests {
|
||||
b'a',
|
||||
Node {
|
||||
nelements: 1,
|
||||
local_elements: 1,
|
||||
content: b"ata".to_vec(),
|
||||
children: BTreeMap::new()
|
||||
}
|
||||
@ -174,6 +196,7 @@ mod tests {
|
||||
b'o',
|
||||
Node {
|
||||
nelements: 1,
|
||||
local_elements: 1,
|
||||
content: b"oto".to_vec(),
|
||||
children: BTreeMap::new()
|
||||
}
|
||||
@ -183,6 +206,7 @@ mod tests {
|
||||
assert_eq!(trie.insert(b"coco"), (0, 1));
|
||||
assert_eq!(trie.insert(b"zaza"), (3, 4));
|
||||
assert_eq!(trie.root.nelements, 4);
|
||||
assert_eq!(trie.root.local_elements, 0);
|
||||
assert_eq!(trie.root.content, b"");
|
||||
assert_eq!(trie.root.children.len(), 3);
|
||||
assert_eq!(
|
||||
@ -192,6 +216,7 @@ mod tests {
|
||||
b'c',
|
||||
Node {
|
||||
nelements: 1,
|
||||
local_elements: 1,
|
||||
content: b"coco".to_vec(),
|
||||
children: BTreeMap::new()
|
||||
}
|
||||
@ -200,12 +225,14 @@ mod tests {
|
||||
b't',
|
||||
Node {
|
||||
nelements: 2,
|
||||
local_elements: 0,
|
||||
content: b"t".to_vec(),
|
||||
children: BTreeMap::from_iter([
|
||||
(
|
||||
b'a',
|
||||
Node {
|
||||
nelements: 1,
|
||||
local_elements: 1,
|
||||
content: b"ata".to_vec(),
|
||||
children: BTreeMap::new()
|
||||
}
|
||||
@ -214,6 +241,7 @@ mod tests {
|
||||
b'o',
|
||||
Node {
|
||||
nelements: 1,
|
||||
local_elements: 1,
|
||||
content: b"oto".to_vec(),
|
||||
children: BTreeMap::new()
|
||||
}
|
||||
@ -225,6 +253,7 @@ mod tests {
|
||||
b'z',
|
||||
Node {
|
||||
nelements: 1,
|
||||
local_elements: 1,
|
||||
content: b"zaza".to_vec(),
|
||||
children: BTreeMap::new()
|
||||
}
|
||||
@ -245,6 +274,19 @@ mod tests {
|
||||
assert_eq!(trie.root.nelements, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_prefix() {
|
||||
let mut trie = Trie::new();
|
||||
trie.insert(b"toto");
|
||||
trie.insert(b"to");
|
||||
|
||||
assert_eq!(trie.root.nelements, 2);
|
||||
assert_eq!(trie.remove(b"to"), Ok(()));
|
||||
assert_eq!(trie.root.nelements, 1);
|
||||
assert_eq!(trie.remove(b"toto"), Ok(()));
|
||||
assert_eq!(trie.root.nelements, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate() {
|
||||
let mut trie = Trie::new();
|
||||
@ -254,4 +296,33 @@ mod tests {
|
||||
assert_eq!(trie.remove(b"toto"), Ok(()));
|
||||
assert_eq!(trie.root.nelements, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prefix() {
|
||||
let mut trie = Trie::new();
|
||||
assert_eq!(trie.insert(b"toto"), (0, 1));
|
||||
assert_eq!(trie.insert(b"to"), (0, 1));
|
||||
assert_eq!(trie.root.nelements, 2);
|
||||
assert_eq!(trie.insert(b"toto"), (1, 3));
|
||||
assert_eq!(trie.root.nelements, 3);
|
||||
assert_eq!(trie.insert(b"tototo"), (3, 4));
|
||||
assert_eq!(trie.root.nelements, 4);
|
||||
assert_eq!(trie.remove(b"toto"), Ok(()));
|
||||
assert_eq!(trie.root.nelements, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mismatch() {
|
||||
let m = mismatch(&[0, 1, 2], &[0, 1, 3]);
|
||||
assert_eq!(m, 2);
|
||||
let a = vec![0; 256];
|
||||
let mut b = vec![0; 256];
|
||||
assert_eq!(mismatch(&a, &b), 256);
|
||||
b[130] = 1;
|
||||
assert_eq!(mismatch(&a, &b), 130);
|
||||
b[129] = 1;
|
||||
assert_eq!(mismatch(&a, &b), 129);
|
||||
b[128] = 1;
|
||||
assert_eq!(mismatch(&a, &b), 128);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user