diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 09b8a544..b0f458c6 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,7 +1,7 @@ -use std::{cmp::min, collections::BTreeSet, sync::Arc}; +use std::{cmp::min, collections::HashMap, sync::Arc}; use tokio::sync::{mpsc, oneshot}; -use crate::RadixTrie; +use crate::{radix::NodeId, RadixTrie}; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { @@ -13,6 +13,7 @@ pub(crate) struct BlockAllocation { pub prefix_len: u64, pub allocation_id: u64, + block_allocator: BlockAllocator, } @@ -107,9 +108,8 @@ async fn block_allocator_task( prefill_tokens, response_sender, } => { - let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice()); response_sender - .send(allocator.allocate(tokens, prefill_tokens_slice)) + .send(allocator.allocate(tokens, prefill_tokens)) .unwrap(); } } @@ -133,7 +133,7 @@ pub trait Allocator { fn allocate( &mut self, tokens: u32, - prefill_tokens: Option<&[u32]>, + prefill_tokens: Option>>, ) -> Option<(Vec, Vec, u64, u64)>; fn free(&mut self, blocks: Vec, allocation_id: u64); @@ -160,7 +160,7 @@ impl Allocator for SimpleAllocator { fn allocate( &mut self, tokens: u32, - _prefill_tokens: Option<&[u32]>, + _prefill_tokens: Option>>, ) -> Option<(Vec, Vec, u64, u64)> { // Apply window size let (required_blocks, repeats) = { @@ -204,18 +204,11 @@ impl Allocator for SimpleAllocator { } } -#[derive(Debug)] -struct PrefixBlockState { - /// The block associated wit this prefix. - block_id: u32, - - /// Last prefix block use. - last_accessed: u64, - - ref_count: usize, -} - struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + cache_blocks: RadixTrie, /// Blocks that are immediately available for allocation. @@ -234,6 +227,8 @@ impl RadixAllocator { } RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), cache_blocks: RadixTrie::new(), free_blocks: (1..n_blocks).collect(), } @@ -264,29 +259,30 @@ impl Allocator for RadixAllocator { fn allocate( &mut self, tokens: u32, - prefill_tokens: Option<&[u32]>, + prefill_tokens: Option>>, ) -> Option<(Vec, Vec, u64, u64)> { let mut blocks = vec![]; - let prefix_node = if let Some(prefill_tokens) = prefill_tokens { - let node_id = self.cache_blocks.find(prefill_tokens, &mut blocks); + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); // Even if this allocation fails below, we need to increase he // refcount to ensure that the prefix that was found is not evicted. - self.cache_blocks.incref(node_id); - Some(node_id) + node_id } else { - None + self.cache_blocks.root_id() }; + self.cache_blocks.incref(prefix_node); + let prefix_len = blocks.len(); let suffix_len = tokens - prefix_len as u32; match self.alloc_or_reclaim(suffix_len as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { - if let Some(node_id) = prefix_node { - self.cache_blocks.decref(node_id); - } + self.cache_blocks.decref(prefix_node); return None; } } @@ -294,10 +290,47 @@ impl Allocator for RadixAllocator { // 1:1 mapping of blocks and slots. let slots = blocks.clone(); - Some((blocks, slots, prefix_len as u64, 0)) + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some((blocks, slots, prefix_len as u64, self.allocation_id)) } fn free(&mut self, blocks: Vec, allocation_id: u64) { - todo!() + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks.decref(allocation.prefix_node); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + // TODO: check if the prefill tokens are already in the cache??? + self.cache_blocks + .insert(prefill_tokens, &blocks[..prefill_tokens.len()]); + } + + // Free non-prefill blocks. + self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + } else { + self.free_blocks.extend(blocks); + } } } + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 311c89f3..715ed320 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -14,7 +14,7 @@ use slotmap::{DefaultKey, SlotMap}; // - We store additional information in each node, such as last access // time and a reference count. -type NodeId = DefaultKey; +pub type NodeId = DefaultKey; #[derive(Debug)] pub struct RadixTrie { @@ -252,6 +252,10 @@ impl RadixTrie { self.print_debug_(*child_id, indent + 2); } } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } } #[derive(Debug)]