mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-19 22:30:19 +00:00
Is this enough to make it work ?
This commit is contained in:
parent
1568e82548
commit
f5182c188c
@ -291,7 +291,11 @@ impl State {
|
||||
None
|
||||
}
|
||||
Some(block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
if entry.request.input_length <= prefill_token_budget {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
} else {
|
||||
prefill_tokens = prefill_token_budget;
|
||||
}
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
Some(window_size) => min(
|
||||
|
@ -19,15 +19,17 @@ pub struct RadixAllocator {
|
||||
// This isn't used because the prefix needs to match without the windowing
|
||||
// mechanism. This at worst is overallocating, not necessarily being wrong.
|
||||
window_size: Option<u32>,
|
||||
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
assert_eq!(
|
||||
block_size, 1,
|
||||
"Radix tree allocator only works with block_size=1, was: {}",
|
||||
block_size
|
||||
);
|
||||
// assert_eq!(
|
||||
// block_size, 1,
|
||||
// "Radix tree allocator only works with block_size=1, was: {}",
|
||||
// block_size
|
||||
// );
|
||||
// if window_size.is_some() {
|
||||
// unimplemented!("Window size not supported in the prefix-caching block allocator yet");
|
||||
// }
|
||||
@ -35,11 +37,12 @@ impl RadixAllocator {
|
||||
RadixAllocator {
|
||||
allocation_id: 0,
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(),
|
||||
cache_blocks: RadixTrie::new(block_size as usize),
|
||||
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
window_size,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,10 +94,10 @@ impl Allocator for RadixAllocator {
|
||||
.incref(prefix_node)
|
||||
.expect("Failed to increment refcount");
|
||||
|
||||
let prefix_len = blocks.len();
|
||||
let prefix_len = blocks.len() * self.block_size as usize;
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
let suffix_blocks = suffix_len;
|
||||
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
|
||||
|
||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
@ -107,7 +110,20 @@ impl Allocator for RadixAllocator {
|
||||
}
|
||||
|
||||
// 1:1 mapping of blocks and slots.
|
||||
let slots = blocks.clone();
|
||||
let slots = if self.block_size == 1 {
|
||||
blocks.clone()
|
||||
} else {
|
||||
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
|
||||
'slots: for block_id in &blocks {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() as u32 == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
slots
|
||||
};
|
||||
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
@ -142,12 +158,16 @@ impl Allocator for RadixAllocator {
|
||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
assert_eq!(prefill_tokens.len() % self.block_size as usize, 0);
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(prefill_tokens, &blocks[..prefill_tokens.len()])
|
||||
.insert(
|
||||
prefill_tokens,
|
||||
&blocks[..prefill_tokens.len() / self.block_size as usize],
|
||||
)
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
|
||||
@ -213,17 +233,14 @@ pub struct RadixTrie {
|
||||
/// Time as a monotonically increasing counter to avoid the system
|
||||
/// call that a real time lookup would require.
|
||||
time: u64,
|
||||
}
|
||||
|
||||
impl Default for RadixTrie {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
/// All blocks need to be aligned with this
|
||||
block_size: usize,
|
||||
}
|
||||
|
||||
impl RadixTrie {
|
||||
/// Construct a new radix trie.
|
||||
pub fn new() -> Self {
|
||||
pub fn new(block_size: usize) -> Self {
|
||||
let root = TrieNode::new(vec![], vec![], 0, None);
|
||||
let mut nodes = SlotMap::new();
|
||||
let root = nodes.insert(root);
|
||||
@ -232,13 +249,14 @@ impl RadixTrie {
|
||||
nodes,
|
||||
root,
|
||||
time: 0,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the prefix of the given tokens.
|
||||
///
|
||||
/// The blocks corresponding to the part of the prefix that could be found
|
||||
/// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`.
|
||||
/// are written to `blocks`. The number of blocks is in `0..=tokens.len() / block_size`.
|
||||
/// Returns the identifier of the trie node that contains the longest
|
||||
/// prefix. The node identifier can be used by callers to e.g. increase its
|
||||
/// reference count.
|
||||
@ -256,8 +274,9 @@ impl RadixTrie {
|
||||
if let Some(&child_id) = node.children.get(&key[0]) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
||||
let shared_prefix_len = child.key.shared_prefix_len(key);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len]);
|
||||
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
@ -358,7 +377,8 @@ impl RadixTrie {
|
||||
/// the first 10 elements of the tree **the blocks are not updated**.
|
||||
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
|
||||
self.time += 1;
|
||||
self.insert_(self.root, tokens, blocks)
|
||||
let common = self.insert_(self.root, tokens, blocks)?;
|
||||
Ok(common)
|
||||
}
|
||||
|
||||
/// Insertion worker.
|
||||
@ -372,7 +392,7 @@ impl RadixTrie {
|
||||
// the part of the prefix that is already in the trie to detect
|
||||
// mismatches.
|
||||
|
||||
if tokens.len() != blocks.len() {
|
||||
if tokens.len() != blocks.len() * self.block_size {
|
||||
return Err(TrieError::BlockTokenCountMismatch);
|
||||
}
|
||||
|
||||
@ -383,10 +403,10 @@ impl RadixTrie {
|
||||
.get_mut(child_id)
|
||||
// Unwrap here, since failure is a bug.
|
||||
.expect("Child node does not exist");
|
||||
let shared_prefix_len = child.key.shared_prefix_len(tokens);
|
||||
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
|
||||
|
||||
// We are done, the prefix is already in the trie.
|
||||
if shared_prefix_len == tokens.len() {
|
||||
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
|
||||
return Ok(shared_prefix_len);
|
||||
}
|
||||
|
||||
@ -396,7 +416,7 @@ impl RadixTrie {
|
||||
+ self.insert_(
|
||||
child_id,
|
||||
&tokens[shared_prefix_len..],
|
||||
&blocks[shared_prefix_len..],
|
||||
&blocks[shared_prefix_len / self.block_size..],
|
||||
)?);
|
||||
}
|
||||
|
||||
@ -405,7 +425,7 @@ impl RadixTrie {
|
||||
// remainder of the prefix into the node again
|
||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||
let key = &tokens[shared_prefix_len..];
|
||||
let blocks = &blocks[shared_prefix_len..];
|
||||
let blocks = &blocks[shared_prefix_len / self.block_size..];
|
||||
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
||||
} else {
|
||||
self.add_node(node_id, tokens, blocks);
|
||||
@ -559,18 +579,9 @@ impl TrieNode {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait to get the length of the shared prefix of two sequences.
|
||||
trait SharedPrefixLen {
|
||||
fn shared_prefix_len(&self, other: &Self) -> usize;
|
||||
}
|
||||
|
||||
impl<T> SharedPrefixLen for [T]
|
||||
where
|
||||
T: PartialEq,
|
||||
{
|
||||
fn shared_prefix_len(&self, other: &Self) -> usize {
|
||||
self.iter().zip(other).take_while(|(a, b)| a == b).count()
|
||||
}
|
||||
/// Length of the common prefix of `left` and `right`, rounded down to a
/// multiple of `block_size`.
///
/// Matching tokens that only fill part of a trailing block are not counted,
/// so the result can always be divided by `block_size` to index a block list.
///
/// # Panics
///
/// Panics if `block_size` is zero.
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
    // Index of the first differing token, or the shorter length when one
    // sequence is a prefix of the other.
    let matching = left
        .iter()
        .zip(right.iter())
        .position(|(l, r)| l != r)
        .unwrap_or_else(|| left.len().min(right.len()));

    // Keep only whole blocks.
    let whole_blocks = matching / block_size;
    whole_blocks * block_size
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -579,6 +590,21 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
/// With a block size of 2, an 8-token request is expected to take 4 blocks
/// and 8 slots, and a repeat request with the same 4 prefill tokens is
/// expected to report a cached prefix of 4 tokens.
#[test]
fn allocator_block_size() {
    let mut allocator = RadixAllocator::new(2, 12, None);

    // First request: nothing is cached yet, so no prefix is reported.
    let first = allocator
        .allocate(8, Some(Arc::new(vec![0, 1, 2, 3])))
        .unwrap();
    assert_eq!(first.blocks, vec![8, 9, 10, 11]);
    assert_eq!(first.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
    assert_eq!(first.prefix_len, 0);
    allocator.free(first.blocks.clone(), first.allocation_id);

    // Second request with the same prefill tokens: the leading two blocks
    // (and their four slots) come back unchanged, the rest are new.
    let second = allocator
        .allocate(8, Some(Arc::new(vec![0, 1, 2, 3])))
        .unwrap();
    assert_eq!(second.blocks, vec![8, 9, 6, 7]);
    assert_eq!(second.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
    assert_eq!(second.prefix_len, 4);
}
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
@ -673,7 +699,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_have_correct_prefix_len() {
|
||||
let mut trie = super::RadixTrie::new();
|
||||
let mut trie = RadixTrie::new(1);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
||||
|
||||
@ -694,9 +720,33 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
/// Checks the prefix length returned by `RadixTrie::insert` when the trie
/// uses a block size of 2 (two tokens per block).
#[test]
fn trie_insertions_block_size() {
    let mut trie = RadixTrie::new(2);

    // (tokens, blocks, expected return value of `insert`), applied in order
    // against the same trie.
    let cases: [(&[u32], &[u32], usize); 5] = [
        // First insertion into an empty trie.
        (&[0, 1, 2, 3], &[0, 1], 0),
        // Already exists.
        // But needs to be block_size aligned
        (&[0, 1, 2, 3], &[0, 1], 4),
        // Completely new at root-level
        (&[1, 2, 3, 4], &[1, 2], 0),
        // Contains full prefix, but longer.
        (&[0, 1, 2, 3, 4, 5], &[0, 1, 2], 4),
        // Shares partial prefix, we need a split.
        (&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3], 2),
    ];

    for &(tokens, blocks, expected) in cases.iter() {
        assert_eq!(trie.insert(tokens, blocks).unwrap(), expected);
    }
}
|
||||
|
||||
#[test]
|
||||
fn trie_get_returns_correct_blocks() {
|
||||
let mut trie = super::RadixTrie::new();
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
@ -730,7 +780,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn trie_evict_removes_correct_blocks() {
|
||||
let mut trie = super::RadixTrie::new();
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user