mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Avoid continuous sorting during reclamation
This commit is contained in:
parent
c352a3e231
commit
1a461234d5
@ -1,8 +1,8 @@
|
|||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use std::{
|
use std::{
|
||||||
borrow::BorrowMut,
|
borrow::BorrowMut,
|
||||||
cmp::min,
|
cmp::{min, Reverse},
|
||||||
collections::{hash_map::Entry, HashMap, HashSet},
|
collections::{hash_map::Entry, BinaryHeap, HashMap, HashSet},
|
||||||
hash::{DefaultHasher, Hash, Hasher},
|
hash::{DefaultHasher, Hash, Hasher},
|
||||||
};
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
@ -257,7 +257,7 @@ impl PrefixCache {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn free_prefix(&mut self, prefix_hash: u64) {
|
fn free_prefix_block(&mut self, prefix_hash: u64) {
|
||||||
let state = self
|
let state = self
|
||||||
.cache_blocks
|
.cache_blocks
|
||||||
.remove(&prefix_hash)
|
.remove(&prefix_hash)
|
||||||
@ -269,6 +269,7 @@ impl PrefixCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.leaves.remove(&prefix_hash);
|
self.leaves.remove(&prefix_hash);
|
||||||
|
self.free_blocks.push(state.block_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decref_prefix(&mut self, prefix_hash: u64) {
|
fn decref_prefix(&mut self, prefix_hash: u64) {
|
||||||
@ -293,36 +294,40 @@ impl PrefixCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
|
fn alloc_or_reclaim(&mut self, n_tokens: usize) -> Option<Vec<u32>> {
|
||||||
let n_blocks = (n_tokens + self.block_size - 1) / self.block_size;
|
let n_blocks_needed = (n_tokens + self.block_size - 1) / self.block_size;
|
||||||
let n_blocks_needed = if n_blocks > self.free_blocks.len() {
|
|
||||||
n_blocks - self.free_blocks.len()
|
let mut reclaimable_blocks = self
|
||||||
} else {
|
.leaves
|
||||||
0
|
.iter()
|
||||||
};
|
.map(|prefix_hash| {
|
||||||
|
let state = &self.cache_blocks[prefix_hash];
|
||||||
|
Reverse((state.last_accessed, *prefix_hash, state.predecessor))
|
||||||
|
})
|
||||||
|
.collect::<BinaryHeap<_>>();
|
||||||
|
|
||||||
while self.free_blocks.len() < n_blocks_needed {
|
while self.free_blocks.len() < n_blocks_needed {
|
||||||
// We have to free one block at a time, because removing the LRU
|
// We have to free one block at a time because removing the LRU
|
||||||
// prefix block may make available another prefix block that is
|
// prefix block may make available another prefix block that is
|
||||||
// LRU.
|
// LRU.
|
||||||
//
|
let (_, lru_prefix_hash, predecessor) = reclaimable_blocks.pop()?.0;
|
||||||
// TODO: switch to something like a binary heap to avoid sorting
|
self.free_prefix_block(lru_prefix_hash);
|
||||||
// the set of leaves over and over again.
|
|
||||||
|
|
||||||
let (lru_prefix_hash, lru_block_id) = self
|
// TODO: this is a leaky abstraction, avoid this.
|
||||||
.leaves
|
if let Some(predecessor) = predecessor {
|
||||||
.iter()
|
let state = &self.cache_blocks[&predecessor];
|
||||||
.map(|prefix_hash| (prefix_hash, &self.cache_blocks[prefix_hash]))
|
if state.ref_count == 0 {
|
||||||
.sorted_by_key(|state| state.1.last_accessed)
|
reclaimable_blocks.push(Reverse((
|
||||||
.map(|(prefix_hash, state)| (*prefix_hash, state.block_id))
|
state.last_accessed,
|
||||||
.next()?;
|
predecessor,
|
||||||
|
state.predecessor,
|
||||||
self.free_prefix(lru_prefix_hash);
|
)));
|
||||||
self.free_blocks.push(lru_block_id);
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(
|
Some(
|
||||||
self.free_blocks
|
self.free_blocks
|
||||||
.split_off(self.free_blocks.len() - n_blocks),
|
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user