mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
Doubly linked data structures are still terrible in Rust.
So use fake pointers (slotmap keys) instead.
This commit is contained in:
parent
cd4933cd5a
commit
590fc2c58d
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -4045,6 +4045,7 @@ dependencies = [
|
|||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"slotmap",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
|
@ -33,6 +33,7 @@ rand = "0.8.5"
|
|||||||
reqwest = { version = "0.11.20", features = [] }
|
reqwest = { version = "0.11.20", features = [] }
|
||||||
serde = "1.0.188"
|
serde = "1.0.188"
|
||||||
serde_json = "1.0.107"
|
serde_json = "1.0.107"
|
||||||
|
slotmap = "1.0.7"
|
||||||
thiserror = "1.0.48"
|
thiserror = "1.0.48"
|
||||||
tokenizers = { workspace = true}
|
tokenizers = { workspace = true}
|
||||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
use std::{
|
use std::{cmp::min, collections::BTreeSet, sync::Arc};
|
||||||
cmp::min,
|
|
||||||
collections::{hash_map::Entry, BTreeSet, HashMap},
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
use crate::TrieNode;
|
use crate::RadixTrie;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct BlockAllocation {
|
pub(crate) struct BlockAllocation {
|
||||||
@ -212,7 +208,7 @@ struct PrefixBlockState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct RadixAllocator {
|
struct RadixAllocator {
|
||||||
cache_blocks: TrieNode,
|
cache_blocks: RadixTrie,
|
||||||
|
|
||||||
/// Blocks that are immediately available for allocation.
|
/// Blocks that are immediately available for allocation.
|
||||||
free_blocks: Vec<u32>,
|
free_blocks: Vec<u32>,
|
||||||
@ -236,7 +232,7 @@ impl RadixAllocator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RadixAllocator {
|
RadixAllocator {
|
||||||
cache_blocks: TrieNode::new(vec![], vec![], 0),
|
cache_blocks: RadixTrie::new(),
|
||||||
free_blocks: (1..n_blocks).collect(),
|
free_blocks: (1..n_blocks).collect(),
|
||||||
leaves: BTreeSet::new(),
|
leaves: BTreeSet::new(),
|
||||||
time: 0,
|
time: 0,
|
||||||
|
@ -6,7 +6,7 @@ mod radix;
|
|||||||
|
|
||||||
use crate::client::{ClientError, ShardedClient};
|
use crate::client::{ClientError, ShardedClient};
|
||||||
pub(crate) use backend::BackendV3;
|
pub(crate) use backend::BackendV3;
|
||||||
pub(crate) use radix::TrieNode;
|
pub(crate) use radix::RadixTrie;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
use std::collections::{hash_map::Entry, HashMap};
|
use std::collections::{hash_map::Entry, HashMap};
|
||||||
|
|
||||||
|
use slotmap::{DefaultKey, SlotMap};
|
||||||
|
|
||||||
// Radix trie that is heavily inspired by radix attention from sglang.
|
// Radix trie that is heavily inspired by radix attention from sglang.
|
||||||
//
|
//
|
||||||
// The trie is optimized for prefix caching:
|
// The trie is optimized for prefix caching:
|
||||||
@ -12,16 +14,115 @@ use std::collections::{hash_map::Entry, HashMap};
|
|||||||
// - We store additional information in each node, such as last access
|
// - We store additional information in each node, such as last access
|
||||||
// time and a reference count.
|
// time and a reference count.
|
||||||
|
|
||||||
#[derive(Debug)]
|
type NodeId = DefaultKey;
|
||||||
pub struct TrieNode {
|
|
||||||
children: HashMap<u32, TrieNode>,
|
pub struct RadixTrie {
|
||||||
|
root: DefaultKey,
|
||||||
|
nodes: SlotMap<NodeId, TrieNode>,
|
||||||
|
time: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RadixTrie {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let root = TrieNode::new(vec![], vec![], 0);
|
||||||
|
let mut nodes = SlotMap::new();
|
||||||
|
let root = nodes.insert(root);
|
||||||
|
RadixTrie {
|
||||||
|
nodes,
|
||||||
|
root,
|
||||||
|
time: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn find(&self, key: &[u32], blocks: &mut Vec<u32>) {
|
||||||
|
self.find_(self.root, key, blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_(&self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) {
|
||||||
|
let node = &self.nodes[node_id];
|
||||||
|
|
||||||
|
if let Some(&child_id) = node.children.get(&key[0]) {
|
||||||
|
let child = &self.nodes[child_id];
|
||||||
|
let shared_prefix_len = child.key.shared_prefix_len(key);
|
||||||
|
blocks.extend(&child.blocks[..shared_prefix_len]);
|
||||||
|
|
||||||
|
let key = &key[shared_prefix_len..];
|
||||||
|
if !key.is_empty() {
|
||||||
|
self.find_(child_id, key, blocks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
|
||||||
|
self.time += 1;
|
||||||
|
self.insert_(self.root, key, blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_(&mut self, node_id: NodeId, key: &[u32], blocks: &[u32]) -> usize {
|
||||||
|
assert_eq!(key.len(), blocks.len());
|
||||||
|
|
||||||
|
//let node = self.nodes.get_mut(node).unwrap();
|
||||||
|
|
||||||
|
if let Some(&child_id) = self.nodes[node_id].children.get(&key[0]) {
|
||||||
|
let child = self.nodes.get_mut(child_id).unwrap();
|
||||||
|
let shared_prefix_len = child.key.shared_prefix_len(key);
|
||||||
|
|
||||||
|
// We are done, the prefix is already in the trie.
|
||||||
|
if shared_prefix_len == key.len() {
|
||||||
|
return shared_prefix_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The node's prefix is a prefix of the insertion prefix.
|
||||||
|
if shared_prefix_len == child.key.len() {
|
||||||
|
return shared_prefix_len
|
||||||
|
+ self.insert_(
|
||||||
|
child_id,
|
||||||
|
&key[shared_prefix_len..],
|
||||||
|
&blocks[shared_prefix_len..],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The node's prefix and the insertion prefix only match partially,
|
||||||
|
// split the node to just contain the matching part. Then insert the
|
||||||
|
// remainder of the prefix into the node again.
|
||||||
|
self.split(child_id, shared_prefix_len);
|
||||||
|
let key = &key[shared_prefix_len..];
|
||||||
|
let blocks = &blocks[shared_prefix_len..];
|
||||||
|
self.insert_(child_id, key, blocks)
|
||||||
|
} else {
|
||||||
|
let child = TrieNode::new(key.to_vec(), blocks.to_vec(), self.time);
|
||||||
|
let child_id = self.nodes.insert(child);
|
||||||
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
node.children.insert(key[0], child_id);
|
||||||
|
return key.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split(&mut self, node_id: NodeId, prefix_len: usize) {
|
||||||
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
|
||||||
|
let rest_key = node.key.split_off(prefix_len);
|
||||||
|
let rest_blocks = node.blocks.split_off(prefix_len);
|
||||||
|
let first = rest_key[0];
|
||||||
|
|
||||||
|
let new_id = self
|
||||||
|
.nodes
|
||||||
|
.insert(TrieNode::new(rest_key, rest_blocks, self.time));
|
||||||
|
|
||||||
|
let node = self.nodes.get_mut(node_id).unwrap();
|
||||||
|
node.children.insert(first, new_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TrieNode {
|
||||||
|
children: HashMap<u32, NodeId>,
|
||||||
key: Vec<u32>,
|
key: Vec<u32>,
|
||||||
blocks: Vec<u32>,
|
blocks: Vec<u32>,
|
||||||
last_accessed: u64,
|
last_accessed: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrieNode {
|
impl TrieNode {
|
||||||
pub fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64) -> Self {
|
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64) -> Self {
|
||||||
TrieNode {
|
TrieNode {
|
||||||
children: HashMap::new(),
|
children: HashMap::new(),
|
||||||
key,
|
key,
|
||||||
@ -29,66 +130,6 @@ impl TrieNode {
|
|||||||
last_accessed,
|
last_accessed,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&self, key: &[u32], blocks: &mut Vec<u32>) {
|
|
||||||
if let Some(child) = self.children.get(&key[0]) {
|
|
||||||
let shared_prefix_len = child.key.shared_prefix_len(key);
|
|
||||||
blocks.extend(&child.blocks[..shared_prefix_len]);
|
|
||||||
|
|
||||||
let key = &key[shared_prefix_len..];
|
|
||||||
if !key.is_empty() {
|
|
||||||
child.find(key, blocks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert a prefix into the trie. Returns the length of the shared prefix.
|
|
||||||
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
|
|
||||||
assert_eq!(key.len(), blocks.len());
|
|
||||||
|
|
||||||
match self.children.entry(key[0]) {
|
|
||||||
Entry::Occupied(entry) => {
|
|
||||||
let child = entry.into_mut();
|
|
||||||
let shared_prefix_len = child.key.shared_prefix_len(key);
|
|
||||||
|
|
||||||
// We are done, the prefix is already in the trie.
|
|
||||||
if shared_prefix_len == key.len() {
|
|
||||||
return shared_prefix_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The node's prefix is a prefix of the insertion prefix.
|
|
||||||
if shared_prefix_len == child.key.len() {
|
|
||||||
return shared_prefix_len
|
|
||||||
+ child.insert(&key[shared_prefix_len..], &blocks[shared_prefix_len..]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The node's prefix and the insertion prefix only match partially,
|
|
||||||
// split the node to just contain the matching part. Then insert the
|
|
||||||
// remainder of the prefix into the node again.
|
|
||||||
child.split(shared_prefix_len);
|
|
||||||
let key = &key[shared_prefix_len..];
|
|
||||||
let blocks = &blocks[shared_prefix_len..];
|
|
||||||
child.insert(key, blocks)
|
|
||||||
}
|
|
||||||
Entry::Vacant(entry) => {
|
|
||||||
let child = TrieNode::new(key.to_vec(), blocks.to_vec(), self.last_accessed);
|
|
||||||
entry.insert(child);
|
|
||||||
return key.len();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//node.last_accessed = last_accessed;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn split(&mut self, prefix_len: usize) {
|
|
||||||
let rest_key = self.key.split_off(prefix_len);
|
|
||||||
let rest_blocks = self.blocks.split_off(prefix_len);
|
|
||||||
|
|
||||||
self.children.insert(
|
|
||||||
rest_key[0],
|
|
||||||
TrieNode::new(rest_key, rest_blocks, self.last_accessed),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
trait SharedPrefixLen {
|
trait SharedPrefixLen {
|
||||||
@ -108,56 +149,56 @@ where
|
|||||||
mod tests {
|
mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn insertions_have_correct_prefix_len() {
|
fn insertions_have_correct_prefix_len() {
|
||||||
let mut root = super::TrieNode::new(vec![], vec![], 0);
|
let mut trie = super::RadixTrie::new();
|
||||||
|
|
||||||
assert_eq!(root.insert(&[0, 1, 2], &[0, 1, 2]), 3);
|
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]), 3);
|
||||||
|
|
||||||
// Already exists.
|
// Already exists.
|
||||||
assert_eq!(root.insert(&[0, 1, 2], &[0, 1, 2]), 3);
|
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]), 3);
|
||||||
|
|
||||||
// Completely new at root-level
|
// Completely new at root-level
|
||||||
assert_eq!(root.insert(&[1, 2, 3], &[1, 2, 3]), 3);
|
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]), 3);
|
||||||
|
|
||||||
// Contains full prefix, but longer.
|
// Contains full prefix, but longer.
|
||||||
assert_eq!(root.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]), 5);
|
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]), 5);
|
||||||
|
|
||||||
// Shares partial prefix, we need a split.
|
// Shares partial prefix, we need a split.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
root.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]),
|
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]),
|
||||||
6
|
6
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn prefix_get_returns_correct_blocks() {
|
fn prefix_get_returns_correct_blocks() {
|
||||||
let mut root = super::TrieNode::new(vec![], vec![], 0);
|
let mut trie = super::RadixTrie::new();
|
||||||
root.insert(&[0, 1, 2], &[0, 1, 2]);
|
trie.insert(&[0, 1, 2], &[0, 1, 2]);
|
||||||
root.insert(&[1, 2, 3], &[1, 2, 3]);
|
trie.insert(&[1, 2, 3], &[1, 2, 3]);
|
||||||
root.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]);
|
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]);
|
||||||
root.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]);
|
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]);
|
||||||
|
|
||||||
let mut blocks = Vec::new();
|
let mut blocks = Vec::new();
|
||||||
root.find(&[0], &mut blocks);
|
trie.find(&[0], &mut blocks);
|
||||||
assert_eq!(blocks, vec![0]);
|
assert_eq!(blocks, vec![0]);
|
||||||
|
|
||||||
blocks.clear();
|
blocks.clear();
|
||||||
root.find(&[0, 1, 2], &mut blocks);
|
trie.find(&[0, 1, 2], &mut blocks);
|
||||||
assert_eq!(blocks, vec![0, 1, 2]);
|
assert_eq!(blocks, vec![0, 1, 2]);
|
||||||
|
|
||||||
blocks.clear();
|
blocks.clear();
|
||||||
root.find(&[1, 2, 3], &mut blocks);
|
trie.find(&[1, 2, 3], &mut blocks);
|
||||||
assert_eq!(blocks, vec![1, 2, 3]);
|
assert_eq!(blocks, vec![1, 2, 3]);
|
||||||
|
|
||||||
blocks.clear();
|
blocks.clear();
|
||||||
root.find(&[0, 1, 2, 3], &mut blocks);
|
trie.find(&[0, 1, 2, 3], &mut blocks);
|
||||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||||
|
|
||||||
blocks.clear();
|
blocks.clear();
|
||||||
root.find(&[0, 1, 2, 3, 4], &mut blocks);
|
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||||
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
|
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
|
||||||
|
|
||||||
blocks.clear();
|
blocks.clear();
|
||||||
root.find(&[0, 1, 2, 3, 5], &mut blocks);
|
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
|
||||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
|
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user