mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
Basic test passes
This commit is contained in:
parent
dbb82e274c
commit
9da64a7b16
@ -151,9 +151,11 @@ pub(crate) struct BlockAllocationWithCache {
|
||||
struct BlockState {
|
||||
block_id: u32,
|
||||
last_accessed: u64,
|
||||
predecessor: Option<u64>,
|
||||
ref_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PrefixCache {
|
||||
block_size: usize,
|
||||
cache_partial: bool,
|
||||
@ -180,10 +182,11 @@ impl PrefixCache {
|
||||
mut n_tokens: usize,
|
||||
prefill_tokens: &[u32],
|
||||
) -> Option<BlockAllocationWithCache> {
|
||||
// First try to lookup prefix.
|
||||
let mut hasher = DefaultHasher::new();
|
||||
let mut tokens_from_cache = 0;
|
||||
let mut cache_blocks_hashes = Vec::new();
|
||||
let mut prefix_cache_blocks = Vec::new();
|
||||
let mut prefix_hashes = Vec::new();
|
||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
||||
break;
|
||||
@ -193,22 +196,28 @@ impl PrefixCache {
|
||||
|
||||
match self.cache_blocks.get_mut(&hasher.finish()) {
|
||||
Some(state) => {
|
||||
// Ensure that we don't evict the prefix blocks.
|
||||
state.ref_count += 1;
|
||||
prefix_cache_blocks.push(state.block_id);
|
||||
prefix_hashes.push(hasher.finish());
|
||||
}
|
||||
None => todo!(),
|
||||
}
|
||||
|
||||
if !self.cache_blocks.contains_key(&hasher.finish()) {
|
||||
break;
|
||||
None => break,
|
||||
}
|
||||
|
||||
tokens_from_cache += prefill_chunk.len();
|
||||
cache_blocks_hashes.push(hasher.finish());
|
||||
}
|
||||
|
||||
let blocks = self.alloc_or_reclaim(n_tokens - tokens_from_cache)?;
|
||||
let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
|
||||
Some(blocks) => blocks,
|
||||
None => {
|
||||
prefix_hashes.into_iter().for_each(|hash| {
|
||||
self.cache_blocks.get_mut(&hash).unwrap().ref_count -= 1;
|
||||
});
|
||||
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let mut prefix_cache_blocks = Vec::new();
|
||||
for hash in cache_blocks_hashes {
|
||||
match self.cache_blocks.get_mut(&hash) {
|
||||
Some(info) => {
|
||||
@ -219,6 +228,9 @@ impl PrefixCache {
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("Allocated blocks: {:?}", blocks);
|
||||
eprintln!("Prefix blocks: {:?}", prefix_cache_blocks);
|
||||
|
||||
prefix_cache_blocks.extend(blocks);
|
||||
|
||||
let mut slots = Vec::with_capacity(n_tokens);
|
||||
@ -279,7 +291,7 @@ impl PrefixCache {
|
||||
|
||||
fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
|
||||
let mut predecessor = None;
|
||||
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
|
||||
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
||||
break;
|
||||
@ -298,13 +310,21 @@ impl PrefixCache {
|
||||
entry.insert(BlockState {
|
||||
block_id: *block_id,
|
||||
last_accessed: self.time,
|
||||
predecessor,
|
||||
ref_count: 0,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
predecessor = Some(hasher.finish());
|
||||
}
|
||||
|
||||
self.time += 1;
|
||||
|
||||
let n_prefill_blocks = (prefill_tokens.len() + self.block_size - 1) / self.block_size;
|
||||
for block in &blocks[n_prefill_blocks..] {
|
||||
self.free_blocks.push(*block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -327,6 +347,8 @@ mod tests {
|
||||
);
|
||||
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
|
||||
|
||||
eprintln!("{:?}", cache);
|
||||
|
||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
||||
assert_eq!(
|
||||
allocation,
|
||||
|
Loading…
Reference in New Issue
Block a user