Is this enough to make it work?

This commit is contained in:
Nicolas Patry 2024-08-26 17:43:27 +02:00
parent 1568e82548
commit f5182c188c
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
2 changed files with 95 additions and 41 deletions

View File

@ -291,7 +291,11 @@ impl State {
None
}
Some(block_allocator) => {
prefill_tokens += entry.request.input_length;
if entry.request.input_length <= prefill_token_budget {
prefill_tokens += entry.request.input_length;
} else {
prefill_tokens = prefill_token_budget;
}
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(

View File

@ -19,15 +19,17 @@ pub struct RadixAllocator {
// This isn't used because the prefix needs to match without the windowing
// mechanism. At worst this overallocates; it is not necessarily wrong.
window_size: Option<u32>,
block_size: u32,
}
impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
assert_eq!(
block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}",
block_size
);
// assert_eq!(
// block_size, 1,
// "Radix tree allocator only works with block_size=1, was: {}",
// block_size
// );
// if window_size.is_some() {
// unimplemented!("Window size not supported in the prefix-caching block allocator yet");
// }
@ -35,11 +37,12 @@ impl RadixAllocator {
RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(),
cache_blocks: RadixTrie::new(block_size as usize),
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
window_size,
block_size,
}
}
@ -91,10 +94,10 @@ impl Allocator for RadixAllocator {
.incref(prefix_node)
.expect("Failed to increment refcount");
let prefix_len = blocks.len();
let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;
let suffix_blocks = suffix_len;
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
@ -107,7 +110,20 @@ impl Allocator for RadixAllocator {
}
// Map blocks to slots: 1:1 when block_size is 1, otherwise each block covers block_size consecutive slots.
let slots = blocks.clone();
let slots = if self.block_size == 1 {
blocks.clone()
} else {
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
'slots: for block_id in &blocks {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() as u32 == tokens {
break 'slots;
}
}
}
slots
};
let allocation = RadixAllocation {
prefix_node,
@ -142,12 +158,16 @@ impl Allocator for RadixAllocator {
if let Some(prefill_tokens) = allocation.prefill_tokens {
let prefill_tokens = prefill_tokens.as_slice();
assert_eq!(prefill_tokens.len() % self.block_size as usize, 0);
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
let prefix_len = self
.cache_blocks
.insert(prefill_tokens, &blocks[..prefill_tokens.len()])
.insert(
prefill_tokens,
&blocks[..prefill_tokens.len() / self.block_size as usize],
)
// Unwrap, failing is a programming error.
.expect("Failed to store prefill tokens");
@ -213,17 +233,14 @@ pub struct RadixTrie {
/// Time as a monotonically increasing counter to avoid the system
/// call that a real time lookup would require.
time: u64,
}
impl Default for RadixTrie {
fn default() -> Self {
Self::new()
}
/// All blocks need to be aligned with this
block_size: usize,
}
impl RadixTrie {
/// Construct a new radix trie.
pub fn new() -> Self {
pub fn new(block_size: usize) -> Self {
let root = TrieNode::new(vec![], vec![], 0, None);
let mut nodes = SlotMap::new();
let root = nodes.insert(root);
@ -232,13 +249,14 @@ impl RadixTrie {
nodes,
root,
time: 0,
block_size,
}
}
/// Find the prefix of the given tokens.
///
/// The blocks corresponding to the part of the prefix that could be found
/// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// are written to `blocks`. The number of blocks is in `0..=tokens.len() / block_size`.
/// Returns the identifier of the trie node that contains the longest
/// prefix. The node identifier can be used by callers to e.g. increase its
/// reference count.
@ -256,8 +274,9 @@ impl RadixTrie {
if let Some(&child_id) = node.children.get(&key[0]) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = child.key.shared_prefix_len(key);
blocks.extend(&child.blocks[..shared_prefix_len]);
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
let key = &key[shared_prefix_len..];
if !key.is_empty() {
@ -358,7 +377,8 @@ impl RadixTrie {
/// the first 10 elements of the tree **the blocks are not updated**.
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
self.time += 1;
self.insert_(self.root, tokens, blocks)
let common = self.insert_(self.root, tokens, blocks)?;
Ok(common)
}
/// Insertion worker.
@ -372,7 +392,7 @@ impl RadixTrie {
// the part of the prefix that is already in the trie to detect
// mismatches.
if tokens.len() != blocks.len() {
if tokens.len() != blocks.len() * self.block_size {
return Err(TrieError::BlockTokenCountMismatch);
}
@ -383,10 +403,10 @@ impl RadixTrie {
.get_mut(child_id)
// Unwrap here, since failure is a bug.
.expect("Child node does not exist");
let shared_prefix_len = child.key.shared_prefix_len(tokens);
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
// We are done, the prefix is already in the trie.
if shared_prefix_len == tokens.len() {
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
return Ok(shared_prefix_len);
}
@ -396,7 +416,7 @@ impl RadixTrie {
+ self.insert_(
child_id,
&tokens[shared_prefix_len..],
&blocks[shared_prefix_len..],
&blocks[shared_prefix_len / self.block_size..],
)?);
}
@ -405,7 +425,7 @@ impl RadixTrie {
// remainder of the prefix into the node again
let child_id = self.split_node(child_id, shared_prefix_len);
let key = &tokens[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len / self.block_size..];
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
} else {
self.add_node(node_id, tokens, blocks);
@ -559,18 +579,9 @@ impl TrieNode {
}
}
/// Helper trait to get the length of the shared prefix of two sequences.
trait SharedPrefixLen {
fn shared_prefix_len(&self, other: &Self) -> usize;
}
impl<T> SharedPrefixLen for [T]
where
T: PartialEq,
{
fn shared_prefix_len(&self, other: &Self) -> usize {
self.iter().zip(other).take_while(|(a, b)| a == b).count()
}
/// Length of the prefix shared by `left` and `right`, rounded down to a
/// multiple of `block_size`.
///
/// Only whole blocks count as shared: a match that ends in the middle of a
/// block is truncated to the previous block boundary.
///
/// # Panics
///
/// Panics if `block_size` is zero (division by zero).
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
    // Index of the first differing pair; when every zipped pair matches, the
    // shared run is as long as the shorter of the two slices.
    let matching = left
        .iter()
        .zip(right)
        .position(|(l, r)| l != r)
        .unwrap_or_else(|| left.len().min(right.len()));
    let whole_blocks = matching / block_size;
    whole_blocks * block_size
}
#[cfg(test)]
@ -579,6 +590,21 @@ mod tests {
use super::*;
#[test]
fn allocator_block_size() {
    // Allocator with two tokens per block and 12 blocks in total.
    let mut allocator = RadixAllocator::new(2, 12, None);
    let prompt = Arc::new(vec![0, 1, 2, 3]);

    // First request: 8 tokens need 4 blocks, and nothing is cached yet.
    let first = allocator.allocate(8, Some(Arc::clone(&prompt))).unwrap();
    assert_eq!(first.blocks, vec![8, 9, 10, 11]);
    // Each block `b` expands to the slots `2 * b` and `2 * b + 1`.
    assert_eq!(first.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
    assert_eq!(first.prefix_len, 0);
    allocator.free(first.blocks.clone(), first.allocation_id);

    // Second request with the same prompt: the 4 prefill tokens (blocks 8
    // and 9) are expected to be served from the prefix cache.
    let second = allocator.allocate(8, Some(Arc::clone(&prompt))).unwrap();
    assert_eq!(second.blocks, vec![8, 9, 6, 7]);
    assert_eq!(second.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
    assert_eq!(second.prefix_len, 4);
}
#[test]
fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None);
@ -673,7 +699,7 @@ mod tests {
#[test]
fn trie_insertions_have_correct_prefix_len() {
let mut trie = super::RadixTrie::new();
let mut trie = RadixTrie::new(1);
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
@ -694,9 +720,33 @@ mod tests {
);
}
#[test]
fn trie_insertions_block_size() {
    // Trie in which every block covers two tokens.
    let mut radix = RadixTrie::new(2);

    // First insertion into the trie: no shared prefix is reported.
    let shared = radix.insert(&[0, 1, 2, 3], &[0, 1]).unwrap();
    assert_eq!(shared, 0);

    // Same sequence again: it already exists, and the reported prefix
    // length is aligned to the block size.
    let shared = radix.insert(&[0, 1, 2, 3], &[0, 1]).unwrap();
    assert_eq!(shared, 4);

    // A sequence starting with a different token is completely new at
    // root level.
    let shared = radix.insert(&[1, 2, 3, 4], &[1, 2]).unwrap();
    assert_eq!(shared, 0);

    // Longer sequence that contains the full stored prefix.
    let shared = radix.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap();
    assert_eq!(shared, 4);

    // Only the first block matches, so the existing node has to be split.
    let shared = radix
        .insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
        .unwrap();
    assert_eq!(shared, 2);
}
#[test]
fn trie_get_returns_correct_blocks() {
let mut trie = super::RadixTrie::new();
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
@ -730,7 +780,7 @@ mod tests {
#[test]
fn trie_evict_removes_correct_blocks() {
let mut trie = super::RadixTrie::new();
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();