mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Basic test passes
This commit is contained in:
parent
dbb82e274c
commit
9da64a7b16
@ -151,9 +151,11 @@ pub(crate) struct BlockAllocationWithCache {
|
|||||||
struct BlockState {
|
struct BlockState {
|
||||||
block_id: u32,
|
block_id: u32,
|
||||||
last_accessed: u64,
|
last_accessed: u64,
|
||||||
|
predecessor: Option<u64>,
|
||||||
ref_count: usize,
|
ref_count: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct PrefixCache {
|
pub struct PrefixCache {
|
||||||
block_size: usize,
|
block_size: usize,
|
||||||
cache_partial: bool,
|
cache_partial: bool,
|
||||||
@ -180,10 +182,11 @@ impl PrefixCache {
|
|||||||
mut n_tokens: usize,
|
mut n_tokens: usize,
|
||||||
prefill_tokens: &[u32],
|
prefill_tokens: &[u32],
|
||||||
) -> Option<BlockAllocationWithCache> {
|
) -> Option<BlockAllocationWithCache> {
|
||||||
// First try to lookup prefix.
|
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
let mut tokens_from_cache = 0;
|
let mut tokens_from_cache = 0;
|
||||||
let mut cache_blocks_hashes = Vec::new();
|
let mut cache_blocks_hashes = Vec::new();
|
||||||
|
let mut prefix_cache_blocks = Vec::new();
|
||||||
|
let mut prefix_hashes = Vec::new();
|
||||||
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
for prefill_chunk in prefill_tokens.chunks(self.block_size) {
|
||||||
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
||||||
break;
|
break;
|
||||||
@ -193,22 +196,28 @@ impl PrefixCache {
|
|||||||
|
|
||||||
match self.cache_blocks.get_mut(&hasher.finish()) {
|
match self.cache_blocks.get_mut(&hasher.finish()) {
|
||||||
Some(state) => {
|
Some(state) => {
|
||||||
|
// Ensure that we don't evict the prefix blocks.
|
||||||
state.ref_count += 1;
|
state.ref_count += 1;
|
||||||
|
prefix_cache_blocks.push(state.block_id);
|
||||||
|
prefix_hashes.push(hasher.finish());
|
||||||
}
|
}
|
||||||
None => todo!(),
|
None => break,
|
||||||
}
|
|
||||||
|
|
||||||
if !self.cache_blocks.contains_key(&hasher.finish()) {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens_from_cache += prefill_chunk.len();
|
tokens_from_cache += prefill_chunk.len();
|
||||||
cache_blocks_hashes.push(hasher.finish());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let blocks = self.alloc_or_reclaim(n_tokens - tokens_from_cache)?;
|
let blocks = match self.alloc_or_reclaim(n_tokens - tokens_from_cache) {
|
||||||
|
Some(blocks) => blocks,
|
||||||
|
None => {
|
||||||
|
prefix_hashes.into_iter().for_each(|hash| {
|
||||||
|
self.cache_blocks.get_mut(&hash).unwrap().ref_count -= 1;
|
||||||
|
});
|
||||||
|
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let mut prefix_cache_blocks = Vec::new();
|
|
||||||
for hash in cache_blocks_hashes {
|
for hash in cache_blocks_hashes {
|
||||||
match self.cache_blocks.get_mut(&hash) {
|
match self.cache_blocks.get_mut(&hash) {
|
||||||
Some(info) => {
|
Some(info) => {
|
||||||
@ -219,6 +228,9 @@ impl PrefixCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
eprintln!("Allocated blocks: {:?}", blocks);
|
||||||
|
eprintln!("Prefix blocks: {:?}", prefix_cache_blocks);
|
||||||
|
|
||||||
prefix_cache_blocks.extend(blocks);
|
prefix_cache_blocks.extend(blocks);
|
||||||
|
|
||||||
let mut slots = Vec::with_capacity(n_tokens);
|
let mut slots = Vec::with_capacity(n_tokens);
|
||||||
@ -279,7 +291,7 @@ impl PrefixCache {
|
|||||||
|
|
||||||
fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) {
|
fn free(&mut self, blocks: &[u32], prefill_tokens: &[u32]) {
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
|
let mut predecessor = None;
|
||||||
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
|
for (prefill_chunk, block_id) in prefill_tokens.chunks(self.block_size).zip(blocks.iter()) {
|
||||||
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
if prefill_chunk.len() < self.block_size && !self.cache_partial {
|
||||||
break;
|
break;
|
||||||
@ -298,13 +310,21 @@ impl PrefixCache {
|
|||||||
entry.insert(BlockState {
|
entry.insert(BlockState {
|
||||||
block_id: *block_id,
|
block_id: *block_id,
|
||||||
last_accessed: self.time,
|
last_accessed: self.time,
|
||||||
|
predecessor,
|
||||||
ref_count: 0,
|
ref_count: 0,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
predecessor = Some(hasher.finish());
|
||||||
}
|
}
|
||||||
|
|
||||||
self.time += 1;
|
self.time += 1;
|
||||||
|
|
||||||
|
let n_prefill_blocks = (prefill_tokens.len() + self.block_size - 1) / self.block_size;
|
||||||
|
for block in &blocks[n_prefill_blocks..] {
|
||||||
|
self.free_blocks.push(*block);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,6 +347,8 @@ mod tests {
|
|||||||
);
|
);
|
||||||
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
|
cache.free(&allocation.unwrap().blocks, &[0, 1, 2, 3]);
|
||||||
|
|
||||||
|
eprintln!("{:?}", cache);
|
||||||
|
|
||||||
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
let allocation = cache.alloc(8, &[0, 1, 2, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
allocation,
|
allocation,
|
||||||
|
Loading…
Reference in New Issue
Block a user