mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-06 09:22:10 +00:00
Fixing the Trie in case of exact prefix match split.
This commit is contained in:
parent
50c8ebdef0
commit
3aa8564652
@ -17,6 +17,7 @@ pub struct Trie {
|
|||||||
pub struct Node {
|
pub struct Node {
|
||||||
content: Vec<u8>,
|
content: Vec<u8>,
|
||||||
nelements: usize,
|
nelements: usize,
|
||||||
|
local_elements: usize,
|
||||||
children: BTreeMap<u8, Node>,
|
children: BTreeMap<u8, Node>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,14 +45,17 @@ impl Node {
|
|||||||
Self {
|
Self {
|
||||||
content: vec![],
|
content: vec![],
|
||||||
nelements: 0,
|
nelements: 0,
|
||||||
|
local_elements: 0,
|
||||||
children: BTreeMap::new(),
|
children: BTreeMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert(&mut self, data: &[u8], left: usize) -> (usize, usize) {
|
fn insert(&mut self, data: &[u8]) -> (usize, usize) {
|
||||||
let (start, stop) = if self.nelements == 0 {
|
let (start, stop) = if self.nelements == 0 {
|
||||||
self.content = data.to_vec();
|
self.content = data.to_vec();
|
||||||
(left, left + 1)
|
assert_eq!(self.local_elements, 0);
|
||||||
|
self.local_elements = 1;
|
||||||
|
(0, self.local_elements)
|
||||||
} else {
|
} else {
|
||||||
let mismatch = mismatch(data, &self.content);
|
let mismatch = mismatch(data, &self.content);
|
||||||
if mismatch == self.content.len() {
|
if mismatch == self.content.len() {
|
||||||
@ -62,54 +66,66 @@ impl Node {
|
|||||||
.iter()
|
.iter()
|
||||||
.take_while(|(&d, _)| d < *c)
|
.take_while(|(&d, _)| d < *c)
|
||||||
.map(|(_, n)| n.nelements)
|
.map(|(_, n)| n.nelements)
|
||||||
.sum();
|
.sum::<usize>()
|
||||||
|
+ self.local_elements;
|
||||||
let next_node = self.children.entry(*c).or_insert(Node::new());
|
let next_node = self.children.entry(*c).or_insert(Node::new());
|
||||||
next_node.insert(&data[mismatch..], left)
|
let (inner_left, inner_right) = next_node.insert(&data[mismatch..]);
|
||||||
|
(inner_left + left, inner_right + left)
|
||||||
} else {
|
} else {
|
||||||
(0, self.nelements + 1)
|
assert_eq!(data.len(), self.content.len());
|
||||||
|
self.local_elements += 1;
|
||||||
|
(0, self.local_elements)
|
||||||
};
|
};
|
||||||
(left + start, left + stop)
|
(start, stop)
|
||||||
} else {
|
} else {
|
||||||
// Partial match, split node
|
// Partial match, split node
|
||||||
let left = self.content[mismatch..].to_vec();
|
let left_content = self.content[mismatch..].to_vec();
|
||||||
let right = data[mismatch..].to_vec();
|
let right_content = data[mismatch..].to_vec();
|
||||||
|
|
||||||
let children = std::mem::take(&mut self.children);
|
let children = std::mem::take(&mut self.children);
|
||||||
let mut children_content = vec![
|
let mut children_content = vec![
|
||||||
(left, children, self.nelements),
|
(left_content, children, self.nelements, self.local_elements),
|
||||||
(right, BTreeMap::new(), 1),
|
(right_content, BTreeMap::new(), 1, 1),
|
||||||
];
|
];
|
||||||
children_content.sort_by(|a, b| a.0.cmp(&b.0));
|
children_content.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
self.content.truncate(mismatch);
|
self.content.truncate(mismatch);
|
||||||
self.children.clear();
|
self.children.clear();
|
||||||
for (child_content, children, nelements) in children_content {
|
self.local_elements = 0;
|
||||||
|
for (child_content, children, nelements, local_elements) in children_content {
|
||||||
if !child_content.is_empty() {
|
if !child_content.is_empty() {
|
||||||
let c = child_content[0];
|
let c = child_content[0];
|
||||||
let child = Node {
|
let child = Node {
|
||||||
content: child_content,
|
content: child_content,
|
||||||
nelements,
|
nelements,
|
||||||
|
local_elements,
|
||||||
children,
|
children,
|
||||||
};
|
};
|
||||||
self.children.insert(c, child);
|
self.children.insert(c, child);
|
||||||
|
} else {
|
||||||
|
self.local_elements += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let c = data[mismatch];
|
let (start, stop) = if let Some(c) = data.get(mismatch) {
|
||||||
let left: usize = self
|
let start = self
|
||||||
.children
|
.children
|
||||||
.iter()
|
.iter()
|
||||||
.take_while(|(&d, _)| d < c)
|
.take_while(|(&d, _)| d < *c)
|
||||||
.map(|(_, n)| n.nelements)
|
.map(|(_, n)| n.nelements)
|
||||||
.sum();
|
.sum::<usize>()
|
||||||
(left, left + 1)
|
+ self.local_elements;
|
||||||
|
(start, start + 1)
|
||||||
|
} else {
|
||||||
|
(0, self.local_elements)
|
||||||
|
};
|
||||||
|
(start, stop)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.nelements += 1;
|
self.nelements += 1;
|
||||||
(start, stop)
|
(start, stop)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
#[cfg(debug_assertions)]
|
||||||
#[allow(dead_code)]
|
|
||||||
fn remove(&mut self, data: &[u8]) -> Result<(), Error> {
|
fn remove(&mut self, data: &[u8]) -> Result<(), Error> {
|
||||||
|
// TODO reclaim the nodes too.
|
||||||
let mismatch = mismatch(data, &self.content);
|
let mismatch = mismatch(data, &self.content);
|
||||||
if mismatch != self.content.len() {
|
if mismatch != self.content.len() {
|
||||||
Err(Error::MissingEntry)
|
Err(Error::MissingEntry)
|
||||||
@ -118,6 +134,10 @@ impl Node {
|
|||||||
if let Some(node) = self.children.get_mut(c) {
|
if let Some(node) = self.children.get_mut(c) {
|
||||||
node.remove(&data[mismatch..])?;
|
node.remove(&data[mismatch..])?;
|
||||||
}
|
}
|
||||||
|
} else if self.local_elements == 0 {
|
||||||
|
return Err(Error::MissingEntry);
|
||||||
|
} else {
|
||||||
|
self.local_elements -= 1;
|
||||||
}
|
}
|
||||||
self.nelements -= 1;
|
self.nelements -= 1;
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -132,7 +152,7 @@ impl Trie {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert(&mut self, data: &[u8]) -> (usize, usize) {
|
pub fn insert(&mut self, data: &[u8]) -> (usize, usize) {
|
||||||
self.root.insert(data, 0)
|
self.root.insert(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
@ -158,6 +178,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(trie.root.nelements, 2);
|
assert_eq!(trie.root.nelements, 2);
|
||||||
assert_eq!(trie.root.content, b"t");
|
assert_eq!(trie.root.content, b"t");
|
||||||
|
assert_eq!(trie.root.local_elements, 0);
|
||||||
assert_eq!(trie.root.children.len(), 2);
|
assert_eq!(trie.root.children.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
trie.root.children,
|
trie.root.children,
|
||||||
@ -166,6 +187,7 @@ mod tests {
|
|||||||
b'a',
|
b'a',
|
||||||
Node {
|
Node {
|
||||||
nelements: 1,
|
nelements: 1,
|
||||||
|
local_elements: 1,
|
||||||
content: b"ata".to_vec(),
|
content: b"ata".to_vec(),
|
||||||
children: BTreeMap::new()
|
children: BTreeMap::new()
|
||||||
}
|
}
|
||||||
@ -174,6 +196,7 @@ mod tests {
|
|||||||
b'o',
|
b'o',
|
||||||
Node {
|
Node {
|
||||||
nelements: 1,
|
nelements: 1,
|
||||||
|
local_elements: 1,
|
||||||
content: b"oto".to_vec(),
|
content: b"oto".to_vec(),
|
||||||
children: BTreeMap::new()
|
children: BTreeMap::new()
|
||||||
}
|
}
|
||||||
@ -183,6 +206,7 @@ mod tests {
|
|||||||
assert_eq!(trie.insert(b"coco"), (0, 1));
|
assert_eq!(trie.insert(b"coco"), (0, 1));
|
||||||
assert_eq!(trie.insert(b"zaza"), (3, 4));
|
assert_eq!(trie.insert(b"zaza"), (3, 4));
|
||||||
assert_eq!(trie.root.nelements, 4);
|
assert_eq!(trie.root.nelements, 4);
|
||||||
|
assert_eq!(trie.root.local_elements, 0);
|
||||||
assert_eq!(trie.root.content, b"");
|
assert_eq!(trie.root.content, b"");
|
||||||
assert_eq!(trie.root.children.len(), 3);
|
assert_eq!(trie.root.children.len(), 3);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -192,6 +216,7 @@ mod tests {
|
|||||||
b'c',
|
b'c',
|
||||||
Node {
|
Node {
|
||||||
nelements: 1,
|
nelements: 1,
|
||||||
|
local_elements: 1,
|
||||||
content: b"coco".to_vec(),
|
content: b"coco".to_vec(),
|
||||||
children: BTreeMap::new()
|
children: BTreeMap::new()
|
||||||
}
|
}
|
||||||
@ -200,12 +225,14 @@ mod tests {
|
|||||||
b't',
|
b't',
|
||||||
Node {
|
Node {
|
||||||
nelements: 2,
|
nelements: 2,
|
||||||
|
local_elements: 0,
|
||||||
content: b"t".to_vec(),
|
content: b"t".to_vec(),
|
||||||
children: BTreeMap::from_iter([
|
children: BTreeMap::from_iter([
|
||||||
(
|
(
|
||||||
b'a',
|
b'a',
|
||||||
Node {
|
Node {
|
||||||
nelements: 1,
|
nelements: 1,
|
||||||
|
local_elements: 1,
|
||||||
content: b"ata".to_vec(),
|
content: b"ata".to_vec(),
|
||||||
children: BTreeMap::new()
|
children: BTreeMap::new()
|
||||||
}
|
}
|
||||||
@ -214,6 +241,7 @@ mod tests {
|
|||||||
b'o',
|
b'o',
|
||||||
Node {
|
Node {
|
||||||
nelements: 1,
|
nelements: 1,
|
||||||
|
local_elements: 1,
|
||||||
content: b"oto".to_vec(),
|
content: b"oto".to_vec(),
|
||||||
children: BTreeMap::new()
|
children: BTreeMap::new()
|
||||||
}
|
}
|
||||||
@ -225,6 +253,7 @@ mod tests {
|
|||||||
b'z',
|
b'z',
|
||||||
Node {
|
Node {
|
||||||
nelements: 1,
|
nelements: 1,
|
||||||
|
local_elements: 1,
|
||||||
content: b"zaza".to_vec(),
|
content: b"zaza".to_vec(),
|
||||||
children: BTreeMap::new()
|
children: BTreeMap::new()
|
||||||
}
|
}
|
||||||
@ -245,6 +274,19 @@ mod tests {
|
|||||||
assert_eq!(trie.root.nelements, 1);
|
assert_eq!(trie.root.nelements, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn delete_prefix() {
|
||||||
|
let mut trie = Trie::new();
|
||||||
|
trie.insert(b"toto");
|
||||||
|
trie.insert(b"to");
|
||||||
|
|
||||||
|
assert_eq!(trie.root.nelements, 2);
|
||||||
|
assert_eq!(trie.remove(b"to"), Ok(()));
|
||||||
|
assert_eq!(trie.root.nelements, 1);
|
||||||
|
assert_eq!(trie.remove(b"toto"), Ok(()));
|
||||||
|
assert_eq!(trie.root.nelements, 0);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn duplicate() {
|
fn duplicate() {
|
||||||
let mut trie = Trie::new();
|
let mut trie = Trie::new();
|
||||||
@ -254,4 +296,33 @@ mod tests {
|
|||||||
assert_eq!(trie.remove(b"toto"), Ok(()));
|
assert_eq!(trie.remove(b"toto"), Ok(()));
|
||||||
assert_eq!(trie.root.nelements, 1);
|
assert_eq!(trie.root.nelements, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prefix() {
|
||||||
|
let mut trie = Trie::new();
|
||||||
|
assert_eq!(trie.insert(b"toto"), (0, 1));
|
||||||
|
assert_eq!(trie.insert(b"to"), (0, 1));
|
||||||
|
assert_eq!(trie.root.nelements, 2);
|
||||||
|
assert_eq!(trie.insert(b"toto"), (1, 3));
|
||||||
|
assert_eq!(trie.root.nelements, 3);
|
||||||
|
assert_eq!(trie.insert(b"tototo"), (3, 4));
|
||||||
|
assert_eq!(trie.root.nelements, 4);
|
||||||
|
assert_eq!(trie.remove(b"toto"), Ok(()));
|
||||||
|
assert_eq!(trie.root.nelements, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mismatch() {
|
||||||
|
let m = mismatch(&[0, 1, 2], &[0, 1, 3]);
|
||||||
|
assert_eq!(m, 2);
|
||||||
|
let a = vec![0; 256];
|
||||||
|
let mut b = vec![0; 256];
|
||||||
|
assert_eq!(mismatch(&a, &b), 256);
|
||||||
|
b[130] = 1;
|
||||||
|
assert_eq!(mismatch(&a, &b), 130);
|
||||||
|
b[129] = 1;
|
||||||
|
assert_eq!(mismatch(&a, &b), 129);
|
||||||
|
b[128] = 1;
|
||||||
|
assert_eq!(mismatch(&a, &b), 128);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user