mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fixed the radix tree.
Used a slice everywhere in radix.rs to keep the cheap Arc cloning instead of recomputing the input_ids.
This commit is contained in:
parent
f952024533
commit
785c6e4893
@ -71,6 +71,8 @@ impl Allocator for RadixAllocator {
|
|||||||
let mut blocks = vec![];
|
let mut blocks = vec![];
|
||||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||||
let node_id = self.cache_blocks.find(
|
let node_id = self.cache_blocks.find(
|
||||||
|
// XXX This is super important we cannot match an entire prefix
|
||||||
|
// otherwise input_ids is empty and the shard code cannot handle that.
|
||||||
&prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
|
&prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
|
||||||
&mut blocks,
|
&mut blocks,
|
||||||
);
|
);
|
||||||
@ -150,7 +152,9 @@ impl Allocator for RadixAllocator {
|
|||||||
.expect("Failed to decrement refcount");
|
.expect("Failed to decrement refcount");
|
||||||
|
|
||||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||||
let prefill_tokens = prefill_tokens.as_slice();
|
// XXX We matched everything except the last token
|
||||||
|
let prefill_tokens =
|
||||||
|
&prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)];
|
||||||
|
|
||||||
// If there are prefill tokens that did not come from the cache,
|
// If there are prefill tokens that did not come from the cache,
|
||||||
// add them to the cache.
|
// add them to the cache.
|
||||||
@ -612,13 +616,17 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn allocator_block_size() {
|
fn allocator_block_size() {
|
||||||
let mut cache = RadixAllocator::new(2, 12, None);
|
let mut cache = RadixAllocator::new(2, 12, None);
|
||||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation = cache
|
||||||
|
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||||
assert_eq!(allocation.prefix_len, 0);
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation = cache
|
||||||
|
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||||
assert_eq!(allocation.prefix_len, 4);
|
assert_eq!(allocation.prefix_len, 4);
|
||||||
@ -627,13 +635,17 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn allocator_block_size_non_aligned() {
|
fn allocator_block_size_non_aligned() {
|
||||||
let mut cache = RadixAllocator::new(2, 12, None);
|
let mut cache = RadixAllocator::new(2, 12, None);
|
||||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
let allocation = cache
|
||||||
|
.allocate(7, Some(Arc::new(vec![0, 1, 2, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||||
assert_eq!(allocation.prefix_len, 0);
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
let allocation = cache
|
||||||
|
.allocate(7, Some(Arc::new(vec![0, 1, 2, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||||
assert_eq!(allocation.prefix_len, 2);
|
assert_eq!(allocation.prefix_len, 2);
|
||||||
@ -642,13 +654,17 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn allocator_reuses_prefixes() {
|
fn allocator_reuses_prefixes() {
|
||||||
let mut cache = RadixAllocator::new(1, 12, None);
|
let mut cache = RadixAllocator::new(1, 12, None);
|
||||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation = cache
|
||||||
|
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.blocks, allocation.slots);
|
assert_eq!(allocation.blocks, allocation.slots);
|
||||||
assert_eq!(allocation.prefix_len, 0);
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation = cache
|
||||||
|
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.prefix_len, 4);
|
assert_eq!(allocation.prefix_len, 4);
|
||||||
}
|
}
|
||||||
@ -656,11 +672,13 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn allocator_collects_older_prefixes_first() {
|
fn allocator_collects_older_prefixes_first() {
|
||||||
let mut cache = RadixAllocator::new(1, 7, None);
|
let mut cache = RadixAllocator::new(1, 7, None);
|
||||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation1 = cache
|
||||||
|
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
||||||
assert_eq!(allocation1.prefix_len, 0);
|
assert_eq!(allocation1.prefix_len, 0);
|
||||||
|
|
||||||
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
|
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5, 99]))).unwrap();
|
||||||
assert_eq!(allocation2.blocks, vec![1, 2]);
|
assert_eq!(allocation2.blocks, vec![1, 2]);
|
||||||
assert_eq!(allocation2.prefix_len, 0);
|
assert_eq!(allocation2.prefix_len, 0);
|
||||||
|
|
||||||
@ -668,7 +686,9 @@ mod tests {
|
|||||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||||
|
|
||||||
// We should get the blocks of the first allocation, since they are more recent.
|
// We should get the blocks of the first allocation, since they are more recent.
|
||||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
|
let allocation3 = cache
|
||||||
|
.allocate(4, Some(Arc::new(vec![6, 7, 8, 9, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
|
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
|
||||||
assert_eq!(allocation3.prefix_len, 0);
|
assert_eq!(allocation3.prefix_len, 0);
|
||||||
}
|
}
|
||||||
@ -676,13 +696,19 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn allocator_frees_fully_overlapping_prefills() {
|
fn allocator_frees_fully_overlapping_prefills() {
|
||||||
let mut cache = RadixAllocator::new(1, 10, None);
|
let mut cache = RadixAllocator::new(1, 10, None);
|
||||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation1 = cache
|
||||||
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
|
let allocation2 = cache
|
||||||
|
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||||
|
|
||||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation3 = cache
|
||||||
|
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
|
||||||
|
.unwrap();
|
||||||
assert_eq!(allocation3.prefix_len, 4);
|
assert_eq!(allocation3.prefix_len, 4);
|
||||||
|
|
||||||
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
||||||
@ -692,20 +718,20 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn allocator_frees_partially_overlapping_prefills() {
|
fn allocator_frees_partially_overlapping_prefills() {
|
||||||
let mut cache = RadixAllocator::new(1, 20, None);
|
let mut cache = RadixAllocator::new(1, 20, None);
|
||||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
|
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 99]))).unwrap();
|
||||||
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
||||||
assert_eq!(allocation1.prefix_len, 0);
|
assert_eq!(allocation1.prefix_len, 0);
|
||||||
|
|
||||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||||
|
|
||||||
let allocation2 = cache
|
let allocation2 = cache
|
||||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99])))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
|
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
|
||||||
assert_eq!(allocation2.prefix_len, 2);
|
assert_eq!(allocation2.prefix_len, 2);
|
||||||
|
|
||||||
let allocation3 = cache
|
let allocation3 = cache
|
||||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99])))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
|
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
|
||||||
assert_eq!(allocation3.prefix_len, 2);
|
assert_eq!(allocation3.prefix_len, 2);
|
||||||
@ -717,14 +743,14 @@ mod tests {
|
|||||||
assert_eq!(cache.free_blocks.len(), 11);
|
assert_eq!(cache.free_blocks.len(), 11);
|
||||||
|
|
||||||
let allocation4 = cache
|
let allocation4 = cache
|
||||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99])))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
|
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
|
||||||
assert_eq!(allocation4.prefix_len, 6);
|
assert_eq!(allocation4.prefix_len, 6);
|
||||||
assert_eq!(cache.free_blocks.len(), 11);
|
assert_eq!(cache.free_blocks.len(), 11);
|
||||||
|
|
||||||
let allocation5 = cache
|
let allocation5 = cache
|
||||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99])))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
|
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
|
||||||
assert_eq!(allocation5.prefix_len, 6);
|
assert_eq!(allocation5.prefix_len, 6);
|
||||||
|
Loading…
Reference in New Issue
Block a user