mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
Shake out some issues, add correct removal order test
This commit is contained in:
parent
6d0094e5d4
commit
c352a3e231
@ -171,9 +171,6 @@ pub struct PrefixCache {
|
|||||||
/// on its preceding prefix block.
|
/// on its preceding prefix block.
|
||||||
cache_blocks: HashMap<u64, PrefixBlockState>,
|
cache_blocks: HashMap<u64, PrefixBlockState>,
|
||||||
|
|
||||||
/// Whether to cache partial blocks.
|
|
||||||
cache_partial: bool,
|
|
||||||
|
|
||||||
/// Blocks that are immediately available for allocation.
|
/// Blocks that are immediately available for allocation.
|
||||||
free_blocks: Vec<u32>,
|
free_blocks: Vec<u32>,
|
||||||
|
|
||||||
@ -185,11 +182,10 @@ pub struct PrefixCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PrefixCache {
|
impl PrefixCache {
|
||||||
pub fn new(block_size: usize, n_blocks: usize, cache_partial: bool) -> Self {
|
pub fn new(block_size: usize, n_blocks: usize) -> Self {
|
||||||
PrefixCache {
|
PrefixCache {
|
||||||
block_size,
|
block_size,
|
||||||
cache_blocks: HashMap::new(),
|
cache_blocks: HashMap::new(),
|
||||||
cache_partial,
|
|
||||||
free_blocks: (1..n_blocks as u32).collect(),
|
free_blocks: (1..n_blocks as u32).collect(),
|
||||||
leaves: HashSet::new(),
|
leaves: HashSet::new(),
|
||||||
time: 0,
|
time: 0,
|
||||||
@ -202,11 +198,10 @@ impl PrefixCache {
|
|||||||
prefill_tokens: &[u32],
|
prefill_tokens: &[u32],
|
||||||
) -> Option<BlockAllocationWithCache> {
|
) -> Option<BlockAllocationWithCache> {
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
let mut tokens_from_cache = 0;
|
|
||||||
let mut prefix_cache_blocks = Vec::new();
|
let mut prefix_cache_blocks = Vec::new();
|
||||||
let mut prefix_hashes = Vec::new();
|
let mut prefix_hashes = Vec::new();
|
||||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||||
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
if prefill_chunk.len() < self.block_size {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,12 +220,11 @@ impl PrefixCache {
|
|||||||
self.incref_prefix(prefix_hash);
|
self.incref_prefix(prefix_hash);
|
||||||
prefix_hashes.push(prefix_hash);
|
prefix_hashes.push(prefix_hash);
|
||||||
prefix_cache_blocks.push(block_id);
|
prefix_cache_blocks.push(block_id);
|
||||||
|
|
||||||
tokens_from_cache += prefill_chunk.len();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get tokens for the remaining prefill and decode.
|
// Get tokens for the remaining prefill and decode.
|
||||||
let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
|
let blocks = match self.alloc_or_reclaim(n_tokens - (prefix_hashes.len() * self.block_size))
|
||||||
|
{
|
||||||
Some(blocks) => blocks,
|
Some(blocks) => blocks,
|
||||||
None => {
|
None => {
|
||||||
// If the allocation fails, we have relinquish our use of the
|
// If the allocation fails, we have relinquish our use of the
|
||||||
@ -247,7 +241,6 @@ impl PrefixCache {
|
|||||||
|
|
||||||
let mut slots = Vec::with_capacity(n_tokens);
|
let mut slots = Vec::with_capacity(n_tokens);
|
||||||
for block_id in prefix_cache_blocks.iter() {
|
for block_id in prefix_cache_blocks.iter() {
|
||||||
// TODO: fixme: doesn't work with cache_partial yet.
|
|
||||||
for s in
|
for s in
|
||||||
(*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
|
(*block_id * self.block_size as u32)..((*block_id + 1) * self.block_size as u32)
|
||||||
{
|
{
|
||||||
@ -301,7 +294,7 @@ impl PrefixCache {
|
|||||||
|
|
||||||
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
|
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
|
||||||
let n_blocks = (n_tokens + self.block_size - 1) / self.block_size;
|
let n_blocks = (n_tokens + self.block_size - 1) / self.block_size;
|
||||||
let n_blocks_needed = if n_tokens > self.free_blocks.len() {
|
let n_blocks_needed = if n_blocks > self.free_blocks.len() {
|
||||||
n_blocks - self.free_blocks.len()
|
n_blocks - self.free_blocks.len()
|
||||||
} else {
|
} else {
|
||||||
0
|
0
|
||||||
@ -337,7 +330,7 @@ impl PrefixCache {
|
|||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
let mut predecessor = None;
|
let mut predecessor = None;
|
||||||
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
|
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
|
||||||
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
if prefill_chunk.len() < self.block_size {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -357,6 +350,9 @@ impl PrefixCache {
|
|||||||
predecessor,
|
predecessor,
|
||||||
ref_count: 0,
|
ref_count: 0,
|
||||||
});
|
});
|
||||||
|
if let Some(predecessor) = predecessor {
|
||||||
|
self.incref_prefix(predecessor);
|
||||||
|
}
|
||||||
self.leaves.insert(prefix_hash);
|
self.leaves.insert(prefix_hash);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -381,7 +377,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_prefix_cache() {
|
fn test_prefix_cache() {
|
||||||
let mut cache = PrefixCache::new(4, 3, false);
|
let mut cache = PrefixCache::new(4, 3);
|
||||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
allocation,
|
allocation,
|
||||||
@ -392,8 +388,6 @@ mod tests {
|
|||||||
);
|
);
|
||||||
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
|
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
|
||||||
|
|
||||||
eprintln!("{:?}", cache);
|
|
||||||
|
|
||||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
allocation,
|
allocation,
|
||||||
@ -403,4 +397,39 @@ mod tests {
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_older_prefixes_are_collected_first() {
|
||||||
|
let mut cache = PrefixCache::new(2, 4);
|
||||||
|
let allocation1 = cache.alloc(4, &[0, 1, 2, 3]);
|
||||||
|
assert_eq!(
|
||||||
|
allocation1,
|
||||||
|
Some(BlockAllocationWithCache {
|
||||||
|
blocks: vec![2, 3],
|
||||||
|
slots: (4..8).collect()
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
let allocation2 = cache.alloc(2, &[4, 5]);
|
||||||
|
assert_eq!(
|
||||||
|
allocation2,
|
||||||
|
Some(BlockAllocationWithCache {
|
||||||
|
blocks: vec![1],
|
||||||
|
slots: (2..4).collect()
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
cache.free(&allocation1.unwrap().blocks, &[0, 1, 2, 3]);
|
||||||
|
cache.free(&allocation2.unwrap().blocks, &[4, 5]);
|
||||||
|
|
||||||
|
// We should get the blocks of the first allocation, since they are more recent.
|
||||||
|
let allocation3 = cache.alloc(4, &[6, 7, 8, 9]);
|
||||||
|
assert_eq!(
|
||||||
|
allocation3,
|
||||||
|
Some(BlockAllocationWithCache {
|
||||||
|
blocks: vec![3, 2],
|
||||||
|
slots: vec![6, 7, 4, 5]
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user