mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
Walk up to predecessors
This commit is contained in:
parent
9da64a7b16
commit
3b6bef4078
@ -2,7 +2,7 @@ use itertools::Itertools;
|
|||||||
use std::{
|
use std::{
|
||||||
borrow::BorrowMut,
|
borrow::BorrowMut,
|
||||||
cmp::min,
|
cmp::min,
|
||||||
collections::{hash_map::Entry, HashMap},
|
collections::{hash_map::Entry, HashMap, HashSet},
|
||||||
hash::{DefaultHasher, Hash, Hasher},
|
hash::{DefaultHasher, Hash, Hasher},
|
||||||
};
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
@ -160,6 +160,7 @@ pub struct PrefixCache {
|
|||||||
block_size: usize,
|
block_size: usize,
|
||||||
cache_partial: bool,
|
cache_partial: bool,
|
||||||
free_blocks: Vec<u32>,
|
free_blocks: Vec<u32>,
|
||||||
|
leaves: HashSet<u64>,
|
||||||
cache_blocks: HashMap<u64, BlockState>,
|
cache_blocks: HashMap<u64, BlockState>,
|
||||||
|
|
||||||
// Avoid a system call, use a counter for time.
|
// Avoid a system call, use a counter for time.
|
||||||
@ -173,18 +174,18 @@ impl PrefixCache {
|
|||||||
cache_blocks: HashMap::new(),
|
cache_blocks: HashMap::new(),
|
||||||
cache_partial,
|
cache_partial,
|
||||||
free_blocks: (1..n_blocks as u32).collect(),
|
free_blocks: (1..n_blocks as u32).collect(),
|
||||||
|
leaves: HashSet::new(),
|
||||||
time: 0,
|
time: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn alloc(
|
fn alloc(
|
||||||
&mut self,
|
&mut self,
|
||||||
mut n_tokens: usize,
|
n_tokens: usize,
|
||||||
prefill_tokens: &[u32],
|
prefill_tokens: &[u32],
|
||||||
) -> Option<BlockAllocationWithCache> {
|
) -> Option<BlockAllocationWithCache> {
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
let mut tokens_from_cache = 0;
|
let mut tokens_from_cache = 0;
|
||||||
let mut cache_blocks_hashes = Vec::new();
|
|
||||||
let mut prefix_cache_blocks = Vec::new();
|
let mut prefix_cache_blocks = Vec::new();
|
||||||
let mut prefix_hashes = Vec::new();
|
let mut prefix_hashes = Vec::new();
|
||||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||||
@ -194,15 +195,14 @@ impl PrefixCache {
|
|||||||
|
|
||||||
prefill_chunk.hash(&mut hasher);
|
prefill_chunk.hash(&mut hasher);
|
||||||
|
|
||||||
match self.cache_blocks.get_mut(&hasher.finish()) {
|
let block_id = match self.cache_blocks.get(&hasher.finish()) {
|
||||||
Some(state) => {
|
Some(state) => state.block_id,
|
||||||
// Ensure that we don't evict the prefix blocks.
|
|
||||||
state.ref_count += 1;
|
|
||||||
prefix_cache_blocks.push(state.block_id);
|
|
||||||
prefix_hashes.push(hasher.finish());
|
|
||||||
}
|
|
||||||
None => break,
|
None => break,
|
||||||
}
|
};
|
||||||
|
|
||||||
|
self.incref_prefix(hasher.finish());
|
||||||
|
prefix_hashes.push(hasher.finish());
|
||||||
|
prefix_cache_blocks.push(block_id);
|
||||||
|
|
||||||
tokens_from_cache += prefill_chunk.len();
|
tokens_from_cache += prefill_chunk.len();
|
||||||
}
|
}
|
||||||
@ -218,19 +218,6 @@ impl PrefixCache {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
for hash in cache_blocks_hashes {
|
|
||||||
match self.cache_blocks.get_mut(&hash) {
|
|
||||||
Some(info) => {
|
|
||||||
info.ref_count += 1;
|
|
||||||
prefix_cache_blocks.push(info.block_id);
|
|
||||||
}
|
|
||||||
None => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!("Allocated blocks: {:?}", blocks);
|
|
||||||
eprintln!("Prefix blocks: {:?}", prefix_cache_blocks);
|
|
||||||
|
|
||||||
prefix_cache_blocks.extend(blocks);
|
prefix_cache_blocks.extend(blocks);
|
||||||
|
|
||||||
let mut slots = Vec::with_capacity(n_tokens);
|
let mut slots = Vec::with_capacity(n_tokens);
|
||||||
@ -251,6 +238,40 @@ impl PrefixCache {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn free_prefix(&mut self, prefix_hash: u64) {
|
||||||
|
let state = self
|
||||||
|
.cache_blocks
|
||||||
|
.remove(&prefix_hash)
|
||||||
|
.expect("Unknown hash");
|
||||||
|
|
||||||
|
if let Some(predecessor) = state.predecessor {
|
||||||
|
self.decref_prefix(predecessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.leaves.remove(&prefix_hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decref_prefix(&mut self, prefix_hash: u64) {
|
||||||
|
let state = self
|
||||||
|
.cache_blocks
|
||||||
|
.get_mut(&prefix_hash)
|
||||||
|
.expect("Unknown hash");
|
||||||
|
assert!(state.ref_count > 0);
|
||||||
|
state.ref_count -= 1;
|
||||||
|
if state.ref_count == 0 {
|
||||||
|
self.leaves.insert(prefix_hash);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn incref_prefix(&mut self, prefix_hash: u64) {
|
||||||
|
let state = self
|
||||||
|
.cache_blocks
|
||||||
|
.get_mut(&prefix_hash)
|
||||||
|
.expect("Unknown hash");
|
||||||
|
state.ref_count += 1;
|
||||||
|
self.leaves.remove(&prefix_hash);
|
||||||
|
}
|
||||||
|
|
||||||
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
|
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
|
||||||
let n_blocks = (n_tokens + self.block_size - 1) / self.block_size;
|
let n_blocks = (n_tokens + self.block_size - 1) / self.block_size;
|
||||||
let n_blocks_needed = if n_tokens > self.free_blocks.len() {
|
let n_blocks_needed = if n_tokens > self.free_blocks.len() {
|
||||||
@ -259,28 +280,24 @@ impl PrefixCache {
|
|||||||
0
|
0
|
||||||
};
|
};
|
||||||
|
|
||||||
if n_blocks_needed > 0 {
|
while self.free_blocks.len() < n_blocks_needed {
|
||||||
let removable_blocks = self
|
// We have to free one block at a time, because removing the LRU
|
||||||
.cache_blocks
|
// prefix block may make available another prefix block that is
|
||||||
.iter_mut()
|
// LRU.
|
||||||
// Block must be unused.
|
//
|
||||||
.filter(|(_, state)| state.ref_count == 0)
|
// TODO: switch to something like a binary heap to avoid sorting
|
||||||
// Remove most recent block first.
|
// the set of leaves over and over again.
|
||||||
// TODO: we are not yet removing a prefix in reverse order.
|
|
||||||
.sorted_by_key(|(_, state)| state.last_accessed)
|
|
||||||
// Find enough candidates.
|
|
||||||
.take(n_blocks_needed)
|
|
||||||
.map(|(block_hash, block_state)| (*block_hash, block_state.block_id))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
if removable_blocks.len() < n_blocks_needed {
|
let (lru_prefix_hash, lru_block_id) = self
|
||||||
return None;
|
.leaves
|
||||||
}
|
.iter()
|
||||||
|
.map(|prefix_hash| (prefix_hash, &self.cache_blocks[prefix_hash]))
|
||||||
|
.sorted_by_key(|state| state.1.last_accessed)
|
||||||
|
.map(|(prefix_hash, state)| (*prefix_hash, state.block_id))
|
||||||
|
.next()?;
|
||||||
|
|
||||||
for (block_hash, block_id) in removable_blocks.into_iter() {
|
self.free_prefix(lru_prefix_hash);
|
||||||
self.free_blocks.push(block_id);
|
self.free_blocks.push(lru_block_id);
|
||||||
self.cache_blocks.remove(&block_hash);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(
|
Some(
|
||||||
|
Loading…
Reference in New Issue
Block a user