Add radix cache free, improve allocate

This commit is contained in:
Daniël de Kok 2024-08-06 08:57:34 +00:00
parent 9415b90892
commit 6486887b43
2 changed files with 66 additions and 29 deletions

View File

@ -1,7 +1,7 @@
use std::{cmp::min, collections::BTreeSet, sync::Arc}; use std::{cmp::min, collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::RadixTrie; use crate::{radix::NodeId, RadixTrie};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocation { pub(crate) struct BlockAllocation {
@ -13,6 +13,7 @@ pub(crate) struct BlockAllocation {
pub prefix_len: u64, pub prefix_len: u64,
pub allocation_id: u64, pub allocation_id: u64,
block_allocator: BlockAllocator, block_allocator: BlockAllocator,
} }
@ -107,9 +108,8 @@ async fn block_allocator_task(
prefill_tokens, prefill_tokens,
response_sender, response_sender,
} => { } => {
let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
response_sender response_sender
.send(allocator.allocate(tokens, prefill_tokens_slice)) .send(allocator.allocate(tokens, prefill_tokens))
.unwrap(); .unwrap();
} }
} }
@ -133,7 +133,7 @@ pub trait Allocator {
fn allocate( fn allocate(
&mut self, &mut self,
tokens: u32, tokens: u32,
prefill_tokens: Option<&[u32]>, prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<(Vec<u32>, Vec<u32>, u64, u64)>; ) -> Option<(Vec<u32>, Vec<u32>, u64, u64)>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64); fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
@ -160,7 +160,7 @@ impl Allocator for SimpleAllocator {
fn allocate( fn allocate(
&mut self, &mut self,
tokens: u32, tokens: u32,
_prefill_tokens: Option<&[u32]>, _prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<(Vec<u32>, Vec<u32>, u64, u64)> { ) -> Option<(Vec<u32>, Vec<u32>, u64, u64)> {
// Apply window size // Apply window size
let (required_blocks, repeats) = { let (required_blocks, repeats) = {
@ -204,18 +204,11 @@ impl Allocator for SimpleAllocator {
} }
} }
/// State tracked for a block that is associated with a prefix.
#[derive(Debug)]
struct PrefixBlockState {
/// The block associated with this prefix.
block_id: u32,
/// Last prefix block use.
// NOTE(review): the unit of this value (timestamp vs. logical counter) is
// not visible here — confirm against the code that updates it.
last_accessed: u64,
/// NOTE(review): presumably the number of live allocations that currently
/// use this prefix block — confirm against the code that updates it.
ref_count: usize,
}
struct RadixAllocator { struct RadixAllocator {
allocation_id: u64,
allocations: HashMap<u64, RadixAllocation>,
cache_blocks: RadixTrie, cache_blocks: RadixTrie,
/// Blocks that are immediately available for allocation. /// Blocks that are immediately available for allocation.
@ -234,6 +227,8 @@ impl RadixAllocator {
} }
RadixAllocator { RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(), cache_blocks: RadixTrie::new(),
free_blocks: (1..n_blocks).collect(), free_blocks: (1..n_blocks).collect(),
} }
@ -264,29 +259,30 @@ impl Allocator for RadixAllocator {
fn allocate( fn allocate(
&mut self, &mut self,
tokens: u32, tokens: u32,
prefill_tokens: Option<&[u32]>, prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<(Vec<u32>, Vec<u32>, u64, u64)> { ) -> Option<(Vec<u32>, Vec<u32>, u64, u64)> {
let mut blocks = vec![]; let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens { let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self.cache_blocks.find(prefill_tokens, &mut blocks); let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
// Even if this allocation fails below, we need to increase the // Even if this allocation fails below, we need to increase the
// refcount to ensure that the prefix that was found is not evicted. // refcount to ensure that the prefix that was found is not evicted.
self.cache_blocks.incref(node_id);
Some(node_id) node_id
} else { } else {
None self.cache_blocks.root_id()
}; };
self.cache_blocks.incref(prefix_node);
let prefix_len = blocks.len(); let prefix_len = blocks.len();
let suffix_len = tokens - prefix_len as u32; let suffix_len = tokens - prefix_len as u32;
match self.alloc_or_reclaim(suffix_len as usize) { match self.alloc_or_reclaim(suffix_len as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks), Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => { None => {
if let Some(node_id) = prefix_node { self.cache_blocks.decref(prefix_node);
self.cache_blocks.decref(node_id);
}
return None; return None;
} }
} }
@ -294,10 +290,47 @@ impl Allocator for RadixAllocator {
// 1:1 mapping of blocks and slots. // 1:1 mapping of blocks and slots.
let slots = blocks.clone(); let slots = blocks.clone();
Some((blocks, slots, prefix_len as u64, 0)) let allocation = RadixAllocation {
prefix_node,
cached_prefix_len: prefix_len,
prefill_tokens: prefill_tokens.clone(),
};
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);
Some((blocks, slots, prefix_len as u64, self.allocation_id))
} }
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) { fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
todo!() let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
self.cache_blocks.decref(allocation.prefix_node);
if let Some(prefill_tokens) = allocation.prefill_tokens {
let prefill_tokens = prefill_tokens.as_slice();
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
// TODO: check if the prefill tokens are already in the cache???
self.cache_blocks
.insert(prefill_tokens, &blocks[..prefill_tokens.len()]);
}
// Free non-prefill blocks.
self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
} else {
self.free_blocks.extend(blocks);
}
} }
} }
/// Bookkeeping stored by `RadixAllocator` for each allocation it hands out,
/// keyed by allocation id, so that `free` can undo or finish the work that
/// `allocate` started.
struct RadixAllocation {
/// Trie node whose reference count was incremented by `allocate`; `free`
/// decrements it again. This is the trie root when the request carried no
/// prefill tokens.
prefix_node: NodeId,
/// Number of blocks that the prefix-cache lookup contributed to the
/// allocation (the length of the block list right after the trie lookup).
cached_prefix_len: usize,
/// The prefill tokens passed to `allocate`, if any. `free` uses them to
/// decide which blocks are inserted into the cache and which go back to the
/// free list.
prefill_tokens: Option<Arc<Vec<u32>>>,
}

View File

@ -14,7 +14,7 @@ use slotmap::{DefaultKey, SlotMap};
// - We store additional information in each node, such as last access // - We store additional information in each node, such as last access
// time and a reference count. // time and a reference count.
type NodeId = DefaultKey; pub type NodeId = DefaultKey;
#[derive(Debug)] #[derive(Debug)]
pub struct RadixTrie { pub struct RadixTrie {
@ -252,6 +252,10 @@ impl RadixTrie {
self.print_debug_(*child_id, indent + 2); self.print_debug_(*child_id, indent + 2);
} }
} }
/// Returns the key of this trie's root node (the value of `self.root`).
pub(crate) fn root_id(&self) -> DefaultKey {
self.root
}
} }
#[derive(Debug)] #[derive(Debug)]