Revert the max prefix hit.

This commit is contained in:
Nicolas Patry 2024-09-07 01:19:16 +02:00
parent c67bec168e
commit 1d0847a90e
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
3 changed files with 43 additions and 58 deletions

View File

@ -168,6 +168,8 @@ pub(crate) async fn batching_task(
None None
} else { } else {
// Minimum batch size // Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching.
Some((batch_size as f32 * waiting_served_ratio).floor() as usize) Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
}; };

View File

@ -70,12 +70,9 @@ impl Allocator for RadixAllocator {
) -> Option<BlockAllocation> { ) -> Option<BlockAllocation> {
let mut blocks = vec![]; let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self.cache_blocks.find( let node_id = self
// XXX This is super important we cannot match an entire prefix .cache_blocks
// otherwise input_ids is empty and the shard code cannot handle that. .find(prefill_tokens.as_slice(), &mut blocks);
&prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
&mut blocks,
);
node_id node_id
} else { } else {
self.cache_blocks.root_id() self.cache_blocks.root_id()
@ -92,6 +89,8 @@ impl Allocator for RadixAllocator {
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
match self.alloc_or_reclaim(suffix_blocks as usize) { match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks), Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => { None => {
@ -149,9 +148,7 @@ impl Allocator for RadixAllocator {
.expect("Failed to decrement refcount"); .expect("Failed to decrement refcount");
if let Some(prefill_tokens) = allocation.prefill_tokens { if let Some(prefill_tokens) = allocation.prefill_tokens {
// XXX We matched everything except the last token let prefill_tokens = prefill_tokens.as_slice();
let prefill_tokens =
&prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)];
// If there are prefill tokens that did not come from the cache, // If there are prefill tokens that did not come from the cache,
// add them to the cache. // add them to the cache.
@ -613,17 +610,13 @@ mod tests {
#[test] #[test]
fn allocator_block_size() { fn allocator_block_size() {
let mut cache = RadixAllocator::new(2, 12, None); let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 0); assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id); cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 4); assert_eq!(allocation.prefix_len, 4);
@ -632,17 +625,13 @@ mod tests {
#[test] #[test]
fn allocator_block_size_non_aligned() { fn allocator_block_size_non_aligned() {
let mut cache = RadixAllocator::new(2, 12, None); let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
.allocate(7, Some(Arc::new(vec![0, 1, 2, 99])))
.unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 0); assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id); cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
.allocate(7, Some(Arc::new(vec![0, 1, 2, 99])))
.unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 2); assert_eq!(allocation.prefix_len, 2);
@ -651,17 +640,13 @@ mod tests {
#[test] #[test]
fn allocator_reuses_prefixes() { fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None); let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.blocks, allocation.slots); assert_eq!(allocation.blocks, allocation.slots);
assert_eq!(allocation.prefix_len, 0); assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id); cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.prefix_len, 4); assert_eq!(allocation.prefix_len, 4);
} }
@ -669,13 +654,11 @@ mod tests {
#[test] #[test]
fn allocator_collects_older_prefixes_first() { fn allocator_collects_older_prefixes_first() {
let mut cache = RadixAllocator::new(1, 7, None); let mut cache = RadixAllocator::new(1, 7, None);
let allocation1 = cache let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation1.prefix_len, 0); assert_eq!(allocation1.prefix_len, 0);
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5, 99]))).unwrap(); let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
assert_eq!(allocation2.blocks, vec![1, 2]); assert_eq!(allocation2.blocks, vec![1, 2]);
assert_eq!(allocation2.prefix_len, 0); assert_eq!(allocation2.prefix_len, 0);
@ -683,9 +666,7 @@ mod tests {
cache.free(allocation2.blocks.clone(), allocation2.allocation_id); cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// We should get the blocks of the first allocation, since they are older. // We should get the blocks of the first allocation, since they are older.
let allocation3 = cache let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
.allocate(4, Some(Arc::new(vec![6, 7, 8, 9, 99])))
.unwrap();
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation3.prefix_len, 0); assert_eq!(allocation3.prefix_len, 0);
} }
@ -693,19 +674,13 @@ mod tests {
#[test] #[test]
fn allocator_frees_fully_overlapping_prefills() { fn allocator_frees_fully_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 10, None); let mut cache = RadixAllocator::new(1, 10, None);
let allocation1 = cache let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99]))) let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.unwrap();
let allocation2 = cache
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
cache.free(allocation2.blocks.clone(), allocation2.allocation_id); cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id); cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation3 = cache let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
.allocate(4, Some(Arc::new(vec![0, 1, 2, 3, 99])))
.unwrap();
assert_eq!(allocation3.prefix_len, 4); assert_eq!(allocation3.prefix_len, 4);
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
@ -715,20 +690,20 @@ mod tests {
#[test] #[test]
fn allocator_frees_partially_overlapping_prefills() { fn allocator_frees_partially_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 20, None); let mut cache = RadixAllocator::new(1, 20, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 99]))).unwrap(); let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
assert_eq!(allocation1.prefix_len, 0); assert_eq!(allocation1.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id); cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation2 = cache let allocation2 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99]))) .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap(); .unwrap();
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
assert_eq!(allocation2.prefix_len, 2); assert_eq!(allocation2.prefix_len, 2);
let allocation3 = cache let allocation3 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99]))) .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap(); .unwrap();
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation3.prefix_len, 2); assert_eq!(allocation3.prefix_len, 2);
@ -740,14 +715,14 @@ mod tests {
assert_eq!(cache.free_blocks.len(), 11); assert_eq!(cache.free_blocks.len(), 11);
let allocation4 = cache let allocation4 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5, 99]))) .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap(); .unwrap();
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
assert_eq!(allocation4.prefix_len, 6); assert_eq!(allocation4.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11); assert_eq!(cache.free_blocks.len(), 11);
let allocation5 = cache let allocation5 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7, 99]))) .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap(); .unwrap();
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
assert_eq!(allocation5.prefix_len, 6); assert_eq!(allocation5.prefix_len, 6);

View File

@ -268,6 +268,9 @@ class FlashCausalLMBatch(Batch):
assert ( assert (
prefix_len <= orig_input_length prefix_len <= orig_input_length
), f"Prefix {prefix_len} vs input {orig_input_length}" ), f"Prefix {prefix_len} vs input {orig_input_length}"
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
prefix_ids.append(tokenized_input[:prefix_len]) prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:] tokenized_input = tokenized_input[prefix_len:]
@ -1157,13 +1160,6 @@ class FlashCausalLM(Model):
"input_lengths": input_lengths_tensor, "input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor, "prefix_lengths": prefix_lengths_tensor,
} }
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph self.cuda_graphs[bs]["graph"] = graph
@ -1199,6 +1195,13 @@ class FlashCausalLM(Model):
prefix_lens=prefix_lengths, prefix_lens=prefix_lengths,
prefix_lens_tensor=prefix_lengths_tensor, prefix_lens_tensor=prefix_lengths_tensor,
): ):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self.model.forward( self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1215,6 +1218,13 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1517,9 +1527,7 @@ class FlashCausalLM(Model):
cuda_graph["slots"].fill_(-1) cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
input_lengths + prefix_lens_tensor
)
cuda_graph["prefix_lengths"].zero_() cuda_graph["prefix_lengths"].zero_()
cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor