From cd4933cd5a93a12d8db8582fd1bf69bcb9eb350f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Aug 2024 11:05:03 +0000 Subject: [PATCH] Trie insertion/lookup --- backends/v3/src/block_allocator.rs | 52 +-------- backends/v3/src/lib.rs | 2 + backends/v3/src/queue.rs | 3 + backends/v3/src/radix.rs | 163 +++++++++++++++++++++++++++++ 4 files changed, 172 insertions(+), 48 deletions(-) create mode 100644 backends/v3/src/radix.rs diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 4106f624..03b26b05 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,4 +1,3 @@ -use radix_trie::Trie; use std::{ cmp::min, collections::{hash_map::Entry, BTreeSet, HashMap}, @@ -6,6 +5,8 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; +use crate::TrieNode; + #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { pub blocks: Vec, @@ -211,7 +212,7 @@ struct PrefixBlockState { } struct RadixAllocator { - cache_blocks: Trie, ()>, + cache_blocks: TrieNode, /// Blocks that are immediately available for allocation. free_blocks: Vec, @@ -235,55 +236,10 @@ impl RadixAllocator { } RadixAllocator { - cache_blocks: Trie::new(), + cache_blocks: TrieNode::new(vec![], vec![], 0), free_blocks: (1..n_blocks).collect(), leaves: BTreeSet::new(), time: 0, } } } - -#[derive(Debug)] -struct TrieNode { - children: HashMap, - key: Vec, - blocks: Vec, - last_accessed: u64, -} - -impl TrieNode { - fn new(key: Vec, blocks: Vec, last_accessed: u64) -> Self { - TrieNode { - children: HashMap::new(), - key, - blocks, - last_accessed, - } - } - - // Insert a prefix into the trie. Returns the length of the shared prefix. - fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize { - match self.children.entry(key[0]) { - Entry::Occupied(entry) => { - let child = entry.into_mut(); - let shared_prefix_len = child - .key - .iter() - .zip(key) - .take_while(|(a, b)| a == b) - .count(); - - // We are done, the prefix is already in the trie. - if shared_prefix_len == key.len() { - return shared_prefix_len; - } - - return shared_prefix_len - + child.insert(&key[shared_prefix_len..], &blocks[shared_prefix_len..]); - } - Entry::Vacant(_) => todo!(), - } - - //node.last_accessed = last_accessed; - } -} diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index a6f89169..190274c6 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -2,9 +2,11 @@ mod backend; mod block_allocator; mod client; mod queue; +mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; +pub(crate) use radix::TrieNode; use serde::Serialize; use thiserror::Error; use utoipa::ToSchema; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 235d02b2..b4b5e8a8 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -476,6 +476,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -488,6 +490,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, truncate: 0, decoder_input_details: false, diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 00000000..fb6756c8 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,163 @@ +use std::collections::{hash_map::Entry, HashMap}; + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. + +#[derive(Debug)] +pub struct TrieNode { + children: HashMap, + key: Vec, + blocks: Vec, + last_accessed: u64, +} + +impl TrieNode { + pub fn new(key: Vec, blocks: Vec, last_accessed: u64) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + } + } + + pub fn find(&self, key: &[u32], blocks: &mut Vec) { + if let Some(child) = self.children.get(&key[0]) { + let shared_prefix_len = child.key.shared_prefix_len(key); + blocks.extend(&child.blocks[..shared_prefix_len]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + child.find(key, blocks); + } + } + } + + // Insert a prefix into the trie. Returns the length of the shared prefix. + pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize { + assert_eq!(key.len(), blocks.len()); + + match self.children.entry(key[0]) { + Entry::Occupied(entry) => { + let child = entry.into_mut(); + let shared_prefix_len = child.key.shared_prefix_len(key); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == key.len() { + return shared_prefix_len; + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return shared_prefix_len + + child.insert(&key[shared_prefix_len..], &blocks[shared_prefix_len..]); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again. + child.split(shared_prefix_len); + let key = &key[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + child.insert(key, blocks) + } + Entry::Vacant(entry) => { + let child = TrieNode::new(key.to_vec(), blocks.to_vec(), self.last_accessed); + entry.insert(child); + return key.len(); + } + } + + //node.last_accessed = last_accessed; + } + + fn split(&mut self, prefix_len: usize) { + let rest_key = self.key.split_off(prefix_len); + let rest_blocks = self.blocks.split_off(prefix_len); + + self.children.insert( + rest_key[0], + TrieNode::new(rest_key, rest_blocks, self.last_accessed), + ); + } +} + +trait SharedPrefixLen { + fn shared_prefix_len(&self, other: &Self) -> usize; +} + +impl SharedPrefixLen for [T] +where + T: PartialEq, +{ + fn shared_prefix_len(&self, other: &Self) -> usize { + self.iter().zip(other).take_while(|(a, b)| a == b).count() + } +} + +#[cfg(test)] +mod tests { + #[test] + fn insertions_have_correct_prefix_len() { + let mut root = super::TrieNode::new(vec![], vec![], 0); + + assert_eq!(root.insert(&[0, 1, 2], &[0, 1, 2]), 3); + + // Already exists. + assert_eq!(root.insert(&[0, 1, 2], &[0, 1, 2]), 3); + + // Completely new at root-level + assert_eq!(root.insert(&[1, 2, 3], &[1, 2, 3]), 3); + + // Contains full prefix, but longer. + assert_eq!(root.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]), 5); + + // Shares partial prefix, we need a split. + assert_eq!( + root.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]), + 6 + ); + } + + #[test] + fn prefix_get_returns_correct_blocks() { + let mut root = super::TrieNode::new(vec![], vec![], 0); + root.insert(&[0, 1, 2], &[0, 1, 2]); + root.insert(&[1, 2, 3], &[1, 2, 3]); + root.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]); + root.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]); + + let mut blocks = Vec::new(); + root.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + root.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + root.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + root.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + root.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + root.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } +}