mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-23 08:10:18 +00:00
Add radix cache free, improve allocate
This commit is contained in:
parent
9415b90892
commit
6486887b43
backends/v3/src
@ -1,7 +1,7 @@
|
||||
use std::{cmp::min, collections::BTreeSet, sync::Arc};
|
||||
use std::{cmp::min, collections::HashMap, sync::Arc};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::RadixTrie;
|
||||
use crate::{radix::NodeId, RadixTrie};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocation {
|
||||
@ -13,6 +13,7 @@ pub(crate) struct BlockAllocation {
|
||||
pub prefix_len: u64,
|
||||
|
||||
pub allocation_id: u64,
|
||||
|
||||
block_allocator: BlockAllocator,
|
||||
}
|
||||
|
||||
@ -107,9 +108,8 @@ async fn block_allocator_task(
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
let prefill_tokens_slice = prefill_tokens.as_ref().map(|p| p.as_slice());
|
||||
response_sender
|
||||
.send(allocator.allocate(tokens, prefill_tokens_slice))
|
||||
.send(allocator.allocate(tokens, prefill_tokens))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@ -133,7 +133,7 @@ pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<&[u32]>,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<(Vec<u32>, Vec<u32>, u64, u64)>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
@ -160,7 +160,7 @@ impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<&[u32]>,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<(Vec<u32>, Vec<u32>, u64, u64)> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
@ -204,18 +204,11 @@ impl Allocator for SimpleAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PrefixBlockState {
|
||||
/// The block associated wit this prefix.
|
||||
block_id: u32,
|
||||
|
||||
/// Last prefix block use.
|
||||
last_accessed: u64,
|
||||
|
||||
ref_count: usize,
|
||||
}
|
||||
|
||||
struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
|
||||
allocations: HashMap<u64, RadixAllocation>,
|
||||
|
||||
cache_blocks: RadixTrie,
|
||||
|
||||
/// Blocks that are immediately available for allocation.
|
||||
@ -234,6 +227,8 @@ impl RadixAllocator {
|
||||
}
|
||||
|
||||
RadixAllocator {
|
||||
allocation_id: 0,
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(),
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
}
|
||||
@ -264,29 +259,30 @@ impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<&[u32]>,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<(Vec<u32>, Vec<u32>, u64, u64)> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens {
|
||||
let node_id = self.cache_blocks.find(prefill_tokens, &mut blocks);
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
// Even if this allocation fails below, we need to increase he
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
self.cache_blocks.incref(node_id);
|
||||
|
||||
Some(node_id)
|
||||
node_id
|
||||
} else {
|
||||
None
|
||||
self.cache_blocks.root_id()
|
||||
};
|
||||
|
||||
self.cache_blocks.incref(prefix_node);
|
||||
|
||||
let prefix_len = blocks.len();
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
match self.alloc_or_reclaim(suffix_len as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
if let Some(node_id) = prefix_node {
|
||||
self.cache_blocks.decref(node_id);
|
||||
}
|
||||
self.cache_blocks.decref(prefix_node);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
@ -294,10 +290,47 @@ impl Allocator for RadixAllocator {
|
||||
// 1:1 mapping of blocks and slots.
|
||||
let slots = blocks.clone();
|
||||
|
||||
Some((blocks, slots, prefix_len as u64, 0))
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
cached_prefix_len: prefix_len,
|
||||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
Some((blocks, slots, prefix_len as u64, self.allocation_id))
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
todo!()
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
|
||||
self.cache_blocks.decref(allocation.prefix_node);
|
||||
|
||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
// TODO: check if the prefill tokens are already in the cache???
|
||||
self.cache_blocks
|
||||
.insert(prefill_tokens, &blocks[..prefill_tokens.len()]);
|
||||
}
|
||||
|
||||
// Free non-prefill blocks.
|
||||
self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
|
||||
} else {
|
||||
self.free_blocks.extend(blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RadixAllocation {
|
||||
prefix_node: NodeId,
|
||||
cached_prefix_len: usize,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ use slotmap::{DefaultKey, SlotMap};
|
||||
// - We store additional information in each node, such as last access
|
||||
// time and a reference count.
|
||||
|
||||
type NodeId = DefaultKey;
|
||||
pub type NodeId = DefaultKey;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RadixTrie {
|
||||
@ -252,6 +252,10 @@ impl RadixTrie {
|
||||
self.print_debug_(*child_id, indent + 2);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn root_id(&self) -> DefaultKey {
|
||||
self.root
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
Loading…
Reference in New Issue
Block a user